12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ # Contains code from Aeppl, Copyright (c) 2021-2022, Aesara Developers.
16
+
15
17
# coding: utf-8
16
18
"""
17
19
A collection of common probability distributions for stochastic
26
28
import aesara .tensor as at
27
29
import numpy as np
28
30
29
- from aeppl .logprob import _logprob , logcdf
31
+ from aeppl .logprob import _logprob , logcdf , logprob
30
32
from aesara .graph .basic import Apply , Variable
31
33
from aesara .graph .op import Op
32
34
from aesara .raise_op import Assert
@@ -311,9 +313,22 @@ def moment(rv, size, lower, upper):
311
313
moment = at .full (size , moment )
312
314
return moment
313
315
316
+ def logp (value , lower , upper ):
317
+ res = at .switch (
318
+ at .bitwise_and (at .ge (value , lower ), at .le (value , upper )),
319
+ at .fill (value , - at .log (upper - lower )),
320
+ - np .inf ,
321
+ )
322
+
323
+ return check_parameters (
324
+ res ,
325
+ lower <= upper ,
326
+ msg = "lower <= upper" ,
327
+ )
328
+
314
329
def logcdf (value , lower , upper ):
315
- return at .switch (
316
- at .lt (value , lower ) | at . lt ( upper , lower ) ,
330
+ res = at .switch (
331
+ at .lt (value , lower ),
317
332
- np .inf ,
318
333
at .switch (
319
334
at .lt (value , upper ),
@@ -322,6 +337,12 @@ def logcdf(value, lower, upper):
322
337
),
323
338
)
324
339
340
+ return check_parameters (
341
+ res ,
342
+ lower <= upper ,
343
+ msg = "lower <= upper" ,
344
+ )
345
+
325
346
326
347
@_default_transform .register (Uniform )
327
348
def uniform_default_transform (op , rv ):
@@ -495,6 +516,14 @@ def moment(rv, size, mu, sigma):
495
516
mu = at .full (size , mu )
496
517
return mu
497
518
519
+ def logp (value , mu , sigma ):
520
+ res = - 0.5 * at .pow ((value - mu ) / sigma , 2 ) - at .log (at .sqrt (2.0 * np .pi )) - at .log (sigma )
521
+ return check_parameters (
522
+ res ,
523
+ sigma > 0 ,
524
+ msg = "sigma > 0" ,
525
+ )
526
+
498
527
def logcdf (value , mu , sigma ):
499
528
return check_parameters (
500
529
normal_lcdf (mu , sigma , value ),
@@ -780,6 +809,15 @@ def moment(rv, size, loc, sigma):
780
809
moment = at .full (size , moment )
781
810
return moment
782
811
812
+ def logp (value , loc , sigma ):
813
+ res = - 0.5 * at .pow ((value - loc ) / sigma , 2 ) + at .log (at .sqrt (2.0 / np .pi )) - at .log (sigma )
814
+ res = at .switch (at .ge (value , loc ), res , - np .inf )
815
+ return check_parameters (
816
+ res ,
817
+ sigma > 0 ,
818
+ msg = "sigma > 0" ,
819
+ )
820
+
783
821
def logcdf (value , loc , sigma ):
784
822
z = zvalue (value , mu = loc , sigma = sigma )
785
823
logcdf = at .switch (
@@ -1079,6 +1117,20 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None):
1079
1117
1080
1118
return alpha , beta
1081
1119
1120
+ def logp (value , alpha , beta ):
1121
+ res = (
1122
+ at .switch (at .eq (alpha , 1.0 ), 0.0 , (alpha - 1.0 ) * at .log (value ))
1123
+ + at .switch (at .eq (beta , 1.0 ), 0.0 , (beta - 1.0 ) * at .log1p (- value ))
1124
+ - (at .gammaln (alpha ) + at .gammaln (beta ) - at .gammaln (alpha + beta ))
1125
+ )
1126
+ res = at .switch (at .bitwise_and (at .ge (value , 0.0 ), at .le (value , 1.0 )), res , - np .inf )
1127
+ return check_parameters (
1128
+ res ,
1129
+ alpha > 0 ,
1130
+ beta > 0 ,
1131
+ msg = "alpha > 0, beta > 0" ,
1132
+ )
1133
+
1082
1134
def logcdf (value , alpha , beta ):
1083
1135
logcdf = at .switch (
1084
1136
at .lt (value , 0 ),
@@ -1261,6 +1313,15 @@ def moment(rv, size, mu):
1261
1313
mu = at .full (size , mu )
1262
1314
return mu
1263
1315
1316
+ def logp (value , mu ):
1317
+ res = - at .log (mu ) - value / mu
1318
+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
1319
+ return check_parameters (
1320
+ res ,
1321
+ mu >= 0 ,
1322
+ msg = "mu >= 0" ,
1323
+ )
1324
+
1264
1325
def logcdf (value , mu ):
1265
1326
lam = at .reciprocal (mu )
1266
1327
res = at .switch (
@@ -1334,6 +1395,14 @@ def moment(rv, size, mu, b):
1334
1395
mu = at .full (size , mu )
1335
1396
return mu
1336
1397
1398
+ def logp (value , mu , b ):
1399
+ res = - at .log (2 * b ) - at .abs (value - mu ) / b
1400
+ return check_parameters (
1401
+ res ,
1402
+ b > 0 ,
1403
+ msg = "b > 0" ,
1404
+ )
1405
+
1337
1406
def logcdf (value , mu , b ):
1338
1407
y = (value - mu ) / b
1339
1408
@@ -1524,6 +1593,20 @@ def moment(rv, size, mu, sigma):
1524
1593
mean = at .full (size , mean )
1525
1594
return mean
1526
1595
1596
+ def logp (value , mu , sigma ):
1597
+ res = (
1598
+ - 0.5 * at .pow ((at .log (value ) - mu ) / sigma , 2 )
1599
+ - 0.5 * at .log (2.0 * np .pi )
1600
+ - at .log (sigma )
1601
+ - at .log (value )
1602
+ )
1603
+ res = at .switch (at .gt (value , 0.0 ), res , - np .inf )
1604
+ return check_parameters (
1605
+ res ,
1606
+ sigma > 0 ,
1607
+ msg = "sigma > 0" ,
1608
+ )
1609
+
1527
1610
def logcdf (value , mu , sigma ):
1528
1611
res = at .switch (
1529
1612
at .le (value , 0 ),
@@ -1732,6 +1815,16 @@ def moment(rv, size, alpha, m):
1732
1815
median = at .full (size , median )
1733
1816
return median
1734
1817
1818
+ def logp (value , alpha , m ):
1819
+ res = at .log (alpha ) + logpow (m , alpha ) - logpow (value , alpha + 1.0 )
1820
+ res = at .switch (at .ge (value , m ), res , - np .inf )
1821
+ return check_parameters (
1822
+ res ,
1823
+ alpha > 0 ,
1824
+ m > 0 ,
1825
+ msg = "alpha > 0, m > 0" ,
1826
+ )
1827
+
1735
1828
def logcdf (value , alpha , m ):
1736
1829
arg = (m / value ) ** alpha
1737
1830
@@ -1819,6 +1912,14 @@ def moment(rv, size, alpha, beta):
1819
1912
alpha = at .full (size , alpha )
1820
1913
return alpha
1821
1914
1915
+ def logp (value , alpha , beta ):
1916
+ res = - at .log (np .pi ) - at .log (beta ) - at .log1p (at .pow ((value - alpha ) / beta , 2 ))
1917
+ return check_parameters (
1918
+ res ,
1919
+ beta > 0 ,
1920
+ msg = "beta > 0" ,
1921
+ )
1922
+
1822
1923
def logcdf (value , alpha , beta ):
1823
1924
res = at .log (0.5 + at .arctan ((value - alpha ) / beta ) / np .pi )
1824
1925
return check_parameters (
@@ -1879,6 +1980,15 @@ def moment(rv, size, loc, beta):
1879
1980
beta = at .full (size , beta )
1880
1981
return beta
1881
1982
1983
+ def logp (value , loc , beta ):
1984
+ res = at .log (2 ) + logprob (Cauchy .dist (loc , beta ), value )
1985
+ res = at .switch (at .ge (value , loc ), res , - np .inf )
1986
+ return check_parameters (
1987
+ res ,
1988
+ beta > 0 ,
1989
+ msg = "beta > 0" ,
1990
+ )
1991
+
1882
1992
def logcdf (value , loc , beta ):
1883
1993
res = at .switch (
1884
1994
at .lt (value , loc ),
@@ -1990,6 +2100,17 @@ def moment(rv, size, alpha, inv_beta):
1990
2100
mean = at .full (size , mean )
1991
2101
return mean
1992
2102
2103
+ def logp (value , alpha , inv_beta ):
2104
+ beta = at .reciprocal (inv_beta )
2105
+ res = - at .gammaln (alpha ) + logpow (beta , alpha ) - beta * value + logpow (value , alpha - 1 )
2106
+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
2107
+ return check_parameters (
2108
+ res ,
2109
+ alpha > 0 ,
2110
+ beta > 0 ,
2111
+ msg = "alpha > 0, beta > 0" ,
2112
+ )
2113
+
1993
2114
def logcdf (value , alpha , inv_beta ):
1994
2115
beta = at .reciprocal (inv_beta )
1995
2116
res = at .switch (
@@ -2091,6 +2212,16 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma):
2091
2212
2092
2213
return alpha , beta
2093
2214
2215
+ def logp (value , alpha , beta ):
2216
+ res = - at .gammaln (alpha ) + logpow (beta , alpha ) - beta / value + logpow (value , - alpha - 1 )
2217
+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
2218
+ return check_parameters (
2219
+ res ,
2220
+ alpha > 0 ,
2221
+ beta > 0 ,
2222
+ msg = "alpha > 0, beta > 0" ,
2223
+ )
2224
+
2094
2225
def logcdf (value , alpha , beta ):
2095
2226
res = at .switch (
2096
2227
at .lt (value , 0 ),
@@ -2158,6 +2289,9 @@ def moment(rv, size, nu):
2158
2289
moment = at .full (size , moment )
2159
2290
return moment
2160
2291
2292
+ def logp (value , nu ):
2293
+ return logprob (Gamma .dist (alpha = nu / 2 , beta = 0.5 ), value )
2294
+
2161
2295
def logcdf (value , nu ):
2162
2296
return logcdf (Gamma .dist (alpha = nu / 2 , beta = 0.5 ), value )
2163
2297
@@ -2586,6 +2720,15 @@ def moment(rv, size, mu, kappa):
2586
2720
mu = at .full (size , mu )
2587
2721
return mu
2588
2722
2723
+ def logp (value , mu , kappa ):
2724
+ res = kappa * at .cos (mu - value ) - at .log (2 * np .pi ) - at .log (at .i0 (kappa ))
2725
+ res = at .switch (at .bitwise_and (at .ge (value , - np .pi ), at .le (value , np .pi )), res , - np .inf )
2726
+ return check_parameters (
2727
+ res ,
2728
+ kappa > 0 ,
2729
+ msg = "kappa > 0" ,
2730
+ )
2731
+
2589
2732
2590
2733
class SkewNormalRV (RandomVariable ):
2591
2734
name = "skewnormal"
@@ -2771,6 +2914,20 @@ def moment(rv, size, lower, c, upper):
2771
2914
mean = at .full (size , mean )
2772
2915
return mean
2773
2916
2917
+ def logp (value , lower , c , upper ):
2918
+ res = at .switch (
2919
+ at .lt (value , c ),
2920
+ at .log (2 * (value - lower ) / ((upper - lower ) * (c - lower ))),
2921
+ at .log (2 * (upper - value ) / ((upper - lower ) * (upper - c ))),
2922
+ )
2923
+ res = at .switch (at .bitwise_and (at .le (lower , value ), at .le (value , upper )), res , - np .inf )
2924
+ return check_parameters (
2925
+ res ,
2926
+ lower <= c ,
2927
+ c <= upper ,
2928
+ msg = "lower <= c <= upper" ,
2929
+ )
2930
+
2774
2931
def logcdf (value , lower , c , upper ):
2775
2932
res = at .switch (
2776
2933
at .le (value , lower ),
@@ -2863,6 +3020,15 @@ def moment(rv, size, mu, beta):
2863
3020
mean = at .full (size , mean )
2864
3021
return mean
2865
3022
3023
+ def logp (value , mu , beta ):
3024
+ z = (value - mu ) / beta
3025
+ res = - z - at .exp (- z ) - at .log (beta )
3026
+ return check_parameters (
3027
+ res ,
3028
+ beta > 0 ,
3029
+ msg = "beta > 0" ,
3030
+ )
3031
+
2866
3032
def logcdf (value , mu , beta ):
2867
3033
res = - at .exp (- (value - mu ) / beta )
2868
3034
@@ -3062,6 +3228,15 @@ def moment(rv, size, mu, s):
3062
3228
mu = at .full (size , mu )
3063
3229
return mu
3064
3230
3231
+ def logp (value , mu , s ):
3232
+ z = (value - mu ) / s
3233
+ res = - z - at .log (s ) - 2.0 * at .log1p (at .exp (- z ))
3234
+ return check_parameters (
3235
+ res ,
3236
+ s > 0 ,
3237
+ msg = "s > 0" ,
3238
+ )
3239
+
3065
3240
def logcdf (value , mu , s ):
3066
3241
res = - at .log1pexp (- (value - mu ) / s )
3067
3242
0 commit comments