Skip to content

Commit 42fd461

Browse files
committed
Standardize check_parameters comparison order with message
1 parent 79f590b commit 42fd461

File tree

3 files changed

+189
-56
lines changed

3 files changed

+189
-56
lines changed

pymc/distributions/continuous.py

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def moment(rv, size, mu, sigma):
498498
def logcdf(value, mu, sigma):
499499
return check_parameters(
500500
normal_lcdf(mu, sigma, value),
501-
0 < sigma,
501+
sigma > 0,
502502
msg="sigma > 0",
503503
)
504504

@@ -790,7 +790,7 @@ def logcdf(value, loc, sigma):
790790

791791
return check_parameters(
792792
logcdf,
793-
0 < sigma,
793+
sigma > 0,
794794
msg="sigma > 0",
795795
)
796796

@@ -965,7 +965,11 @@ def logcdf(value, mu, lam, alpha):
965965
)
966966

967967
return check_parameters(
968-
logcdf, 0 < mu, 0 < lam, 0 <= alpha, msg="mu > 0, lam > 0, alpha >= 0"
968+
logcdf,
969+
mu > 0,
970+
lam > 0,
971+
alpha >= 0,
972+
msg="mu > 0, lam > 0, alpha >= 0",
969973
)
970974

971975

@@ -1088,8 +1092,8 @@ def logcdf(value, alpha, beta):
10881092

10891093
return check_parameters(
10901094
logcdf,
1091-
0 < alpha,
1092-
0 < beta,
1095+
alpha > 0,
1096+
beta > 0,
10931097
msg="alpha > 0, beta > 0",
10941098
)
10951099

@@ -1265,7 +1269,11 @@ def logcdf(value, mu):
12651269
at.log1mexp(-lam * value),
12661270
)
12671271

1268-
return check_parameters(res, 0 <= lam, msg="lam >= 0")
1272+
return check_parameters(
1273+
res,
1274+
lam >= 0,
1275+
msg="lam >= 0",
1276+
)
12691277

12701278

12711279
class Laplace(Continuous):
@@ -1341,7 +1349,7 @@ def logcdf(value, mu, b):
13411349

13421350
return check_parameters(
13431351
res,
1344-
0 < b,
1352+
b > 0,
13451353
msg="b > 0",
13461354
)
13471355

@@ -1423,7 +1431,12 @@ def logp(value, b, kappa, mu):
14231431
-value * b * at.sgn(value) * (kappa ** at.sgn(value))
14241432
)
14251433

1426-
return check_parameters(res, 0 < b, 0 < kappa, msg="b > 0, kappa > 0")
1434+
return check_parameters(
1435+
res,
1436+
b > 0,
1437+
kappa > 0,
1438+
msg="b > 0, kappa > 0",
1439+
)
14271440

14281441

14291442
class LogNormal(PositiveContinuous):
@@ -1518,7 +1531,11 @@ def logcdf(value, mu, sigma):
15181531
normal_lcdf(mu, sigma, at.log(value)),
15191532
)
15201533

1521-
return check_parameters(res, 0 < sigma, msg="sigma > 0")
1534+
return check_parameters(
1535+
res,
1536+
sigma > 0,
1537+
msg="sigma > 0",
1538+
)
15221539

15231540

15241541
Lognormal = LogNormal
@@ -1629,7 +1646,12 @@ def logp(value, nu, mu, sigma):
16291646
- (nu + 1.0) / 2.0 * at.log1p(lam * (value - mu) ** 2 / nu)
16301647
)
16311648

1632-
return check_parameters(res, lam > 0, nu > 0, msg="lam > 0, nu > 0")
1649+
return check_parameters(
1650+
res,
1651+
lam > 0,
1652+
nu > 0,
1653+
msg="lam > 0, nu > 0",
1654+
)
16331655

16341656
def logcdf(value, nu, mu, sigma):
16351657
_, sigma = get_tau_sigma(sigma=sigma)
@@ -1640,7 +1662,12 @@ def logcdf(value, nu, mu, sigma):
16401662

16411663
res = at.log(at.betainc(nu / 2.0, nu / 2.0, x))
16421664

1643-
return check_parameters(res, 0 < nu, 0 < sigma, msg="nu > 0, sigma > 0")
1665+
return check_parameters(
1666+
res,
1667+
nu > 0,
1668+
sigma > 0,
1669+
msg="nu > 0, sigma > 0",
1670+
)
16441671

16451672

16461673
class Pareto(BoundedContinuous):
@@ -1718,7 +1745,12 @@ def logcdf(value, alpha, m):
17181745
),
17191746
)
17201747

1721-
return check_parameters(res, 0 < alpha, 0 < m, msg="alpha > 0, m > 0")
1748+
return check_parameters(
1749+
res,
1750+
alpha > 0,
1751+
m > 0,
1752+
msg="alpha > 0, m > 0",
1753+
)
17221754

17231755

17241756
@_default_transform.register(Pareto)
@@ -1791,7 +1823,7 @@ def logcdf(value, alpha, beta):
17911823
res = at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi)
17921824
return check_parameters(
17931825
res,
1794-
0 < beta,
1826+
beta > 0,
17951827
msg="beta > 0",
17961828
)
17971829

@@ -1854,7 +1886,11 @@ def logcdf(value, loc, beta):
18541886
at.log(2 * at.arctan((value - loc) / beta) / np.pi),
18551887
)
18561888

1857-
return check_parameters(res, 0 < beta, msg="beta > 0")
1889+
return check_parameters(
1890+
res,
1891+
beta > 0,
1892+
msg="beta > 0",
1893+
)
18581894

18591895

18601896
class Gamma(PositiveContinuous):
@@ -2062,7 +2098,12 @@ def logcdf(value, alpha, beta):
20622098
at.log(at.gammaincc(alpha, beta / value)),
20632099
)
20642100

2065-
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
2101+
return check_parameters(
2102+
res,
2103+
alpha > 0,
2104+
beta > 0,
2105+
msg="alpha > 0, beta > 0",
2106+
)
20662107

20672108

20682109
class ChiSquared(PositiveContinuous):
@@ -2210,7 +2251,12 @@ def logcdf(value, alpha, beta):
22102251
at.log1mexp(-a),
22112252
)
22122253

2213-
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
2254+
return check_parameters(
2255+
res,
2256+
alpha > 0,
2257+
beta > 0,
2258+
msg="alpha > 0, beta > 0",
2259+
)
22142260

22152261
def logp(value, alpha, beta):
22162262
res = (
@@ -2220,7 +2266,12 @@ def logp(value, alpha, beta):
22202266
- at.pow(value / beta, alpha)
22212267
)
22222268
res = at.switch(at.ge(value, 0.0), res, -np.inf)
2223-
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")
2269+
return check_parameters(
2270+
res,
2271+
alpha > 0,
2272+
beta > 0,
2273+
msg="alpha > 0, beta > 0",
2274+
)
22242275

22252276

22262277
class HalfStudentTRV(RandomVariable):
@@ -2327,7 +2378,12 @@ def logp(value, nu, sigma):
23272378
res,
23282379
)
23292380

2330-
return check_parameters(res, sigma > 0, nu > 0, msg="sigma > 0, nu > 0")
2381+
return check_parameters(
2382+
res,
2383+
sigma > 0,
2384+
nu > 0,
2385+
msg="sigma > 0, nu > 0",
2386+
)
23312387

23322388

23332389
class ExGaussianRV(RandomVariable):
@@ -2442,8 +2498,8 @@ def logp(value, mu, sigma, nu):
24422498
)
24432499
return check_parameters(
24442500
res,
2445-
0 < sigma,
2446-
0 < nu,
2501+
sigma > 0,
2502+
nu > 0,
24472503
msg="nu > 0, sigma > 0",
24482504
)
24492505

@@ -2462,7 +2518,12 @@ def logcdf(value, mu, sigma, nu):
24622518
normal_lcdf(mu, sigma, value),
24632519
)
24642520

2465-
return check_parameters(res, 0 < sigma, 0 < nu, msg="sigma > 0, nu > 0")
2521+
return check_parameters(
2522+
res,
2523+
sigma > 0,
2524+
nu > 0,
2525+
msg="sigma > 0, nu > 0",
2526+
)
24662527

24672528

24682529
class VonMises(CircularContinuous):
@@ -2630,7 +2691,11 @@ def logp(value, mu, sigma, alpha):
26302691
+ (-tau * (value - mu) ** 2 + at.log(tau / np.pi / 2.0)) / 2.0
26312692
)
26322693

2633-
return check_parameters(res, tau > 0, msg="tau > 0")
2694+
return check_parameters(
2695+
res,
2696+
tau > 0,
2697+
msg="tau > 0",
2698+
)
26342699

26352700

26362701
class Triangular(BoundedContinuous):
@@ -2801,7 +2866,11 @@ def moment(rv, size, mu, beta):
28012866
def logcdf(value, mu, beta):
28022867
res = -at.exp(-(value - mu) / beta)
28032868

2804-
return check_parameters(res, 0 < beta, msg="beta > 0")
2869+
return check_parameters(
2870+
res,
2871+
beta > 0,
2872+
msg="beta > 0",
2873+
)
28052874

28062875

28072876
class RiceRV(RandomVariable):
@@ -2998,7 +3067,7 @@ def logcdf(value, mu, s):
29983067

29993068
return check_parameters(
30003069
res,
3001-
0 < s,
3070+
s > 0,
30023071
msg="s > 0",
30033072
)
30043073

@@ -3327,14 +3396,18 @@ def moment(rv, size, mu, sigma):
33273396
def logp(value, mu, sigma):
33283397
scaled = (value - mu) / sigma
33293398
res = -(1 / 2) * (scaled + at.exp(-scaled)) - at.log(sigma) - (1 / 2) * at.log(2 * np.pi)
3330-
return check_parameters(res, 0 < sigma, msg="sigma > 0")
3399+
return check_parameters(
3400+
res,
3401+
sigma > 0,
3402+
msg="sigma > 0",
3403+
)
33313404

33323405
def logcdf(value, mu, sigma):
33333406
scaled = (value - mu) / sigma
33343407
res = at.log(at.erfc(at.exp(-scaled / 2) * (2**-0.5)))
33353408
return check_parameters(
33363409
res,
3337-
0 < sigma,
3410+
sigma > 0,
33383411
msg="sigma > 0",
33393412
)
33403413

0 commit comments

Comments
 (0)