Skip to content

Commit cf13e77

Browse files
committed
Bring back univariate continuous logp methods
Also add parameter check to uniform logcdf
1 parent 42fd461 commit cf13e77

File tree

2 files changed

+182
-5
lines changed

2 files changed

+182
-5
lines changed

pymc/distributions/continuous.py

Lines changed: 178 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Contains code from Aeppl, Copyright (c) 2021-2022, Aesara Developers.
16+
1517
# coding: utf-8
1618
"""
1719
A collection of common probability distributions for stochastic
@@ -26,7 +28,7 @@
2628
import aesara.tensor as at
2729
import numpy as np
2830

29-
from aeppl.logprob import _logprob, logcdf
31+
from aeppl.logprob import _logprob, logcdf, logprob
3032
from aesara.graph.basic import Apply, Variable
3133
from aesara.graph.op import Op
3234
from aesara.raise_op import Assert
@@ -311,9 +313,22 @@ def moment(rv, size, lower, upper):
311313
moment = at.full(size, moment)
312314
return moment
313315

316+
def logp(value, lower, upper):
317+
res = at.switch(
318+
at.bitwise_and(at.ge(value, lower), at.le(value, upper)),
319+
at.fill(value, -at.log(upper - lower)),
320+
-np.inf,
321+
)
322+
323+
return check_parameters(
324+
res,
325+
lower <= upper,
326+
msg="lower <= upper",
327+
)
328+
314329
def logcdf(value, lower, upper):
315-
return at.switch(
316-
at.lt(value, lower) | at.lt(upper, lower),
330+
res = at.switch(
331+
at.lt(value, lower),
317332
-np.inf,
318333
at.switch(
319334
at.lt(value, upper),
@@ -322,6 +337,12 @@ def logcdf(value, lower, upper):
322337
),
323338
)
324339

340+
return check_parameters(
341+
res,
342+
lower <= upper,
343+
msg="lower <= upper",
344+
)
345+
325346

326347
@_default_transform.register(Uniform)
327348
def uniform_default_transform(op, rv):
@@ -495,6 +516,14 @@ def moment(rv, size, mu, sigma):
495516
mu = at.full(size, mu)
496517
return mu
497518

519+
def logp(value, mu, sigma):
520+
res = -0.5 * at.pow((value - mu) / sigma, 2) - at.log(at.sqrt(2.0 * np.pi)) - at.log(sigma)
521+
return check_parameters(
522+
res,
523+
sigma > 0,
524+
msg="sigma > 0",
525+
)
526+
498527
def logcdf(value, mu, sigma):
499528
return check_parameters(
500529
normal_lcdf(mu, sigma, value),
@@ -780,6 +809,15 @@ def moment(rv, size, loc, sigma):
780809
moment = at.full(size, moment)
781810
return moment
782811

812+
def logp(value, loc, sigma):
813+
res = -0.5 * at.pow((value - loc) / sigma, 2) + at.log(at.sqrt(2.0 / np.pi)) - at.log(sigma)
814+
res = at.switch(at.ge(value, loc), res, -np.inf)
815+
return check_parameters(
816+
res,
817+
sigma > 0,
818+
msg="sigma > 0",
819+
)
820+
783821
def logcdf(value, loc, sigma):
784822
z = zvalue(value, mu=loc, sigma=sigma)
785823
logcdf = at.switch(
@@ -1079,6 +1117,20 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None):
10791117

10801118
return alpha, beta
10811119

1120+
def logp(value, alpha, beta):
1121+
res = (
1122+
at.switch(at.eq(alpha, 1.0), 0.0, (alpha - 1.0) * at.log(value))
1123+
+ at.switch(at.eq(beta, 1.0), 0.0, (beta - 1.0) * at.log1p(-value))
1124+
- (at.gammaln(alpha) + at.gammaln(beta) - at.gammaln(alpha + beta))
1125+
)
1126+
res = at.switch(at.bitwise_and(at.ge(value, 0.0), at.le(value, 1.0)), res, -np.inf)
1127+
return check_parameters(
1128+
res,
1129+
alpha > 0,
1130+
beta > 0,
1131+
msg="alpha > 0, beta > 0",
1132+
)
1133+
10821134
def logcdf(value, alpha, beta):
10831135
logcdf = at.switch(
10841136
at.lt(value, 0),
@@ -1261,6 +1313,15 @@ def moment(rv, size, mu):
12611313
mu = at.full(size, mu)
12621314
return mu
12631315

1316+
def logp(value, mu):
1317+
res = -at.log(mu) - value / mu
1318+
res = at.switch(at.ge(value, 0.0), res, -np.inf)
1319+
return check_parameters(
1320+
res,
1321+
mu >= 0,
1322+
msg="mu >= 0",
1323+
)
1324+
12641325
def logcdf(value, mu):
12651326
lam = at.reciprocal(mu)
12661327
res = at.switch(
@@ -1334,6 +1395,14 @@ def moment(rv, size, mu, b):
13341395
mu = at.full(size, mu)
13351396
return mu
13361397

1398+
def logp(value, mu, b):
1399+
res = -at.log(2 * b) - at.abs(value - mu) / b
1400+
return check_parameters(
1401+
res,
1402+
b > 0,
1403+
msg="b > 0",
1404+
)
1405+
13371406
def logcdf(value, mu, b):
13381407
y = (value - mu) / b
13391408

@@ -1524,6 +1593,20 @@ def moment(rv, size, mu, sigma):
15241593
mean = at.full(size, mean)
15251594
return mean
15261595

1596+
def logp(value, mu, sigma):
1597+
res = (
1598+
-0.5 * at.pow((at.log(value) - mu) / sigma, 2)
1599+
- 0.5 * at.log(2.0 * np.pi)
1600+
- at.log(sigma)
1601+
- at.log(value)
1602+
)
1603+
res = at.switch(at.gt(value, 0.0), res, -np.inf)
1604+
return check_parameters(
1605+
res,
1606+
sigma > 0,
1607+
msg="sigma > 0",
1608+
)
1609+
15271610
def logcdf(value, mu, sigma):
15281611
res = at.switch(
15291612
at.le(value, 0),
@@ -1732,6 +1815,16 @@ def moment(rv, size, alpha, m):
17321815
median = at.full(size, median)
17331816
return median
17341817

1818+
def logp(value, alpha, m):
1819+
res = at.log(alpha) + logpow(m, alpha) - logpow(value, alpha + 1.0)
1820+
res = at.switch(at.ge(value, m), res, -np.inf)
1821+
return check_parameters(
1822+
res,
1823+
alpha > 0,
1824+
m > 0,
1825+
msg="alpha > 0, m > 0",
1826+
)
1827+
17351828
def logcdf(value, alpha, m):
17361829
arg = (m / value) ** alpha
17371830

@@ -1819,6 +1912,14 @@ def moment(rv, size, alpha, beta):
18191912
alpha = at.full(size, alpha)
18201913
return alpha
18211914

1915+
def logp(value, alpha, beta):
1916+
res = -at.log(np.pi) - at.log(beta) - at.log1p(at.pow((value - alpha) / beta, 2))
1917+
return check_parameters(
1918+
res,
1919+
beta > 0,
1920+
msg="beta > 0",
1921+
)
1922+
18221923
def logcdf(value, alpha, beta):
18231924
res = at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi)
18241925
return check_parameters(
@@ -1879,6 +1980,15 @@ def moment(rv, size, loc, beta):
18791980
beta = at.full(size, beta)
18801981
return beta
18811982

1983+
def logp(value, loc, beta):
1984+
res = at.log(2) + logprob(Cauchy.dist(loc, beta), value)
1985+
res = at.switch(at.ge(value, loc), res, -np.inf)
1986+
return check_parameters(
1987+
res,
1988+
beta > 0,
1989+
msg="beta > 0",
1990+
)
1991+
18821992
def logcdf(value, loc, beta):
18831993
res = at.switch(
18841994
at.lt(value, loc),
@@ -1990,6 +2100,17 @@ def moment(rv, size, alpha, inv_beta):
19902100
mean = at.full(size, mean)
19912101
return mean
19922102

2103+
def logp(value, alpha, inv_beta):
2104+
beta = at.reciprocal(inv_beta)
2105+
res = -at.gammaln(alpha) + logpow(beta, alpha) - beta * value + logpow(value, alpha - 1)
2106+
res = at.switch(at.ge(value, 0.0), res, -np.inf)
2107+
return check_parameters(
2108+
res,
2109+
alpha > 0,
2110+
beta > 0,
2111+
msg="alpha > 0, beta > 0",
2112+
)
2113+
19932114
def logcdf(value, alpha, inv_beta):
19942115
beta = at.reciprocal(inv_beta)
19952116
res = at.switch(
@@ -2091,6 +2212,16 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma):
20912212

20922213
return alpha, beta
20932214

2215+
def logp(value, alpha, beta):
2216+
res = -at.gammaln(alpha) + logpow(beta, alpha) - beta / value + logpow(value, -alpha - 1)
2217+
res = at.switch(at.ge(value, 0.0), res, -np.inf)
2218+
return check_parameters(
2219+
res,
2220+
alpha > 0,
2221+
beta > 0,
2222+
msg="alpha > 0, beta > 0",
2223+
)
2224+
20942225
def logcdf(value, alpha, beta):
20952226
res = at.switch(
20962227
at.lt(value, 0),
@@ -2158,6 +2289,9 @@ def moment(rv, size, nu):
21582289
moment = at.full(size, moment)
21592290
return moment
21602291

2292+
def logp(value, nu):
2293+
return logprob(Gamma.dist(alpha=nu / 2, beta=0.5), value)
2294+
21612295
def logcdf(value, nu):
21622296
return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value)
21632297

@@ -2586,6 +2720,15 @@ def moment(rv, size, mu, kappa):
25862720
mu = at.full(size, mu)
25872721
return mu
25882722

2723+
def logp(value, mu, kappa):
2724+
res = kappa * at.cos(mu - value) - at.log(2 * np.pi) - at.log(at.i0(kappa))
2725+
res = at.switch(at.bitwise_and(at.ge(value, -np.pi), at.le(value, np.pi)), res, -np.inf)
2726+
return check_parameters(
2727+
res,
2728+
kappa > 0,
2729+
msg="kappa > 0",
2730+
)
2731+
25892732

25902733
class SkewNormalRV(RandomVariable):
25912734
name = "skewnormal"
@@ -2771,6 +2914,20 @@ def moment(rv, size, lower, c, upper):
27712914
mean = at.full(size, mean)
27722915
return mean
27732916

2917+
def logp(value, lower, c, upper):
2918+
res = at.switch(
2919+
at.lt(value, c),
2920+
at.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
2921+
at.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
2922+
)
2923+
res = at.switch(at.bitwise_and(at.le(lower, value), at.le(value, upper)), res, -np.inf)
2924+
return check_parameters(
2925+
res,
2926+
lower <= c,
2927+
c <= upper,
2928+
msg="lower <= c <= upper",
2929+
)
2930+
27742931
def logcdf(value, lower, c, upper):
27752932
res = at.switch(
27762933
at.le(value, lower),
@@ -2863,6 +3020,15 @@ def moment(rv, size, mu, beta):
28633020
mean = at.full(size, mean)
28643021
return mean
28653022

3023+
def logp(value, mu, beta):
3024+
z = (value - mu) / beta
3025+
res = -z - at.exp(-z) - at.log(beta)
3026+
return check_parameters(
3027+
res,
3028+
beta > 0,
3029+
msg="beta > 0",
3030+
)
3031+
28663032
def logcdf(value, mu, beta):
28673033
res = -at.exp(-(value - mu) / beta)
28683034

@@ -3062,6 +3228,15 @@ def moment(rv, size, mu, s):
30623228
mu = at.full(size, mu)
30633229
return mu
30643230

3231+
def logp(value, mu, s):
3232+
z = (value - mu) / s
3233+
res = -z - at.log(s) - 2.0 * at.log1p(at.exp(-z))
3234+
return check_parameters(
3235+
res,
3236+
s > 0,
3237+
msg="s > 0",
3238+
)
3239+
30653240
def logcdf(value, mu, s):
30663241
res = -at.log1pexp(-(value - mu) / s)
30673242

pymc/tests/distributions/test_continuous.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,10 @@ def test_uniform(self):
185185
# Custom logp / logcdf check for invalid parameters
186186
invalid_dist = pm.Uniform.dist(lower=1, upper=0)
187187
with aesara.config.change_flags(mode=Mode("py")):
188-
assert logp(invalid_dist, np.array(0.5)).eval() == -np.inf
189-
assert logcdf(invalid_dist, np.array(2.0)).eval() == -np.inf
188+
with pytest.raises(ParameterValueError):
189+
logp(invalid_dist, np.array(0.5)).eval()
190+
with pytest.raises(ParameterValueError):
191+
logcdf(invalid_dist, np.array(0.5)).eval()
190192

191193
def test_triangular(self):
192194
check_logp(

0 commit comments

Comments
 (0)