Skip to content

Commit 8268157

Browse files
committed
Merge branch 'master' into sample_arg_rename
2 parents cc9edc1 + d47e133 commit 8268157

File tree

6 files changed

+42
-21
lines changed

6 files changed

+42
-21
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
of the success probability. This is faster and more stable than using
99
`p=tt.nnet.sigmoid(logit_p)`.
1010

11+
### Fixes
12+
13+
- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use
14+
log_i0 to compute the logp.
15+
1116
### Deprecations
1217

1318
- DIC and BPIC calculations have been removed

pymc3/backends/report.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import namedtuple
22
import logging
33
import enum
4+
from ..util import is_transformed_name, get_untransformed_name
45

56

67
logger = logging.getLogger('pymc3')
@@ -69,10 +70,15 @@ def _run_convergence_checks(self, trace, model):
6970

7071
from pymc3 import diagnostics
7172

72-
varname = [rv.name for rv in model.free_RVs]
73+
varnames = []
74+
for rv in model.free_RVs:
75+
rv_name = rv.name
76+
if is_transformed_name(rv_name):
77+
rv_name = get_untransformed_name(rv_name)
78+
varnames.append(rv_name)
7379

74-
self._effective_n = effective_n = diagnostics.effective_n(trace, varname)
75-
self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace, varname)
80+
self._effective_n = effective_n = diagnostics.effective_n(trace, varnames)
81+
self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace, varnames)
7682

7783
warnings = []
7884
rhat_max = max(val.max() for val in gelman_rubin.values())

pymc3/distributions/continuous.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pymc3.theanof import floatX
1717
from . import transforms
1818
from pymc3.util import get_variable_name
19-
from .special import i0, i1
19+
from .special import log_i0
2020
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper
2121
from .distribution import Continuous, draw_values, generate_samples
2222

@@ -1850,8 +1850,7 @@ def __init__(self, mu=0.0, kappa=None, transform='circular',
18501850
transform = transforms.Circular()
18511851
super(VonMises, self).__init__(transform=transform, *args, **kwargs)
18521852
self.mean = self.median = self.mode = self.mu = mu = tt.as_tensor_variable(mu)
1853-
self.kappa = kappa = tt.as_tensor_variable(kappa)
1854-
self.variance = 1 - i1(kappa) / i0(kappa)
1853+
self.kappa = kappa = floatX(tt.as_tensor_variable(kappa))
18551854

18561855
assert_negative_support(kappa, 'kappa', 'VonMises')
18571856

@@ -1865,7 +1864,8 @@ def random(self, point=None, size=None, repeat=None):
18651864
def logp(self, value):
18661865
mu = self.mu
18671866
kappa = self.kappa
1868-
return bound(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
1867+
return bound(kappa * tt.cos(mu - value) - (tt.log(2 * np.pi) + log_i0(kappa)),
1868+
kappa > 0, value >= -np.pi, value <= np.pi)
18691869

18701870
def _repr_latex_(self, name=None, dist=None):
18711871
if dist is None:

pymc3/distributions/special.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import numpy as np
22
import theano.tensor as tt
3-
from theano.scalar.basic_scipy import GammaLn, Psi, I0, I1
3+
from theano.scalar.basic_scipy import GammaLn, Psi
44
from theano import scalar
55

6-
__all__ = ['gammaln', 'multigammaln', 'psi', 'i0', 'i1']
6+
__all__ = ['gammaln', 'multigammaln', 'psi', 'log_i0']
77

88
scalar_gammaln = GammaLn(scalar.upgrade_to_float, name='scalar_gammaln')
99
gammaln = tt.Elemwise(scalar_gammaln, name='gammaln')
@@ -12,20 +12,28 @@
1212
def multigammaln(a, p):
1313
"""Multivariate Log Gamma
1414
15-
:Parameters:
16-
a : tensor like
17-
p : int degrees of freedom
18-
p > 0
15+
Parameters
16+
----------
17+
a : tensor like
18+
p : int
19+
degrees of freedom. p > 0
1920
"""
2021
i = tt.arange(1, p + 1)
2122
return (p * (p - 1) * tt.log(np.pi) / 4.
2223
+ tt.sum(gammaln(a + (1. - i) / 2.), axis=0))
2324

24-
scalar_psi = Psi(scalar.upgrade_to_float, name='scalar_psi')
25-
psi = tt.Elemwise(scalar_psi, name='psi')
2625

27-
scalar_i0 = I0(scalar.upgrade_to_float, name='scalar_i0')
28-
i0 = tt.Elemwise(scalar_i0, name='i0')
26+
def log_i0(x):
27+
"""
28+
Calculates the logarithm of the 0 order modified Bessel function of the first kind""
29+
"""
30+
return tt.switch(tt.lt(x, 5), tt.log1p(x**2. / 4. + x**4. / 64. + x**6. / 2304.
31+
+ x**8. / 147456. + x**10. / 14745600.
32+
+ x**12. / 2123366400.),
33+
x - 0.5 * tt.log(2. * np.pi * x) + tt.log1p(1. / (8. * x)
34+
+ 9. / (128. * x**2.) + 225. / (3072. * x**3.)
35+
+ 11025. / (98304. * x**4.)))
2936

30-
scalar_i1 = I1(scalar.upgrade_to_float, name='scalar_i1')
31-
i1 = tt.Elemwise(scalar_i1, name='i1')
37+
38+
scalar_psi = Psi(scalar.upgrade_to_float, name='scalar_psi')
39+
psi = tt.Elemwise(scalar_psi, name='psi')

pymc3/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
258258
tune : int
259259
Number of iterations to tune, if applicable (defaults to 500).
260260
Samplers adjust the step sizes, scalings or similar during
261-
tuning. These samples will be drawn in addition to samples
262-
and discarded unless discard_tuned_samples is set to True.
261+
tuning. Tuning samples will be drawn in addition to the number
262+
specified in the `draws` argument, and will be discarded
263+
unless `discard_tuned_samples` is set to False.
263264
nuts_kwargs : dict
264265
Options for the NUTS sampler. See the docstring of NUTS
265266
for a complete list of options. Common options are

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ def test_ex_gaussian(self, value, mu, sigma, nu, logp):
931931
pt = {'eg': value}
932932
assert_almost_equal(model.fastlogp(pt), logp, decimal=select_by_precision(float64=6, float32=2), err_msg=str(pt))
933933

934+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
934935
def test_vonmises(self):
935936
self.pymc3_matches_scipy(
936937
VonMises, R, {'mu': Circ, 'kappa': Rplus},

0 commit comments

Comments
 (0)