Skip to content

Commit 1dc1f01

Browse files
committed
Fix multivariate metropolis proposal
1 parent d5bc5fb commit 1dc1f01

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

pymc3/step_methods/metropolis.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import numpy.random as nr
33
import theano
4+
import scipy.linalg
45

56
from ..distributions import draw_values
67
from .arraystep import ArrayStepShared, ArrayStep, metrop_select, Competence
@@ -41,8 +42,16 @@ def __call__(self):
4142

4243

4344
class MultivariateNormalProposal(Proposal):
45+
def __init__(self, s):
46+
n, m = s.shape
47+
if n != m:
48+
raise ValueError("Covariance matrix is not symmetric.")
49+
self.n = n
50+
self.chol = scipy.linalg.cholesky(s, lower=True)
51+
4452
def __call__(self, num_draws=None):
45-
return nr.multivariate_normal(mean=np.zeros(self.s.shape[0]), cov=self.s, size=num_draws)
53+
b = np.random.randn(self.n)
54+
return np.dot(self.chol, b)
4655

4756

4857
class Metropolis(ArrayStepShared):
@@ -76,7 +85,7 @@ class Metropolis(ArrayStepShared):
7685
'tune': np.bool,
7786
}]
7887

79-
def __init__(self, vars=None, S=None, proposal_dist=NormalProposal, scaling=1.,
88+
def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
8089
tune=True, tune_interval=100, model=None, mode=None, **kwargs):
8190

8291
model = pm.modelcontext(model)
@@ -87,7 +96,16 @@ def __init__(self, vars=None, S=None, proposal_dist=NormalProposal, scaling=1.,
8796

8897
if S is None:
8998
S = np.ones(sum(v.dsize for v in vars))
90-
self.proposal_dist = proposal_dist(S)
99+
100+
if proposal_dist is not None:
101+
self.proposal_dist = proposal_dist(S)
102+
elif S.ndim == 1:
103+
self.proposal_dist = NormalProposal(S)
104+
elif S.ndim == 2:
105+
self.proposal_dist = MultivariateNormalProposal(S)
106+
else:
107+
raise ValueError("Invalid rank for variance: %s" % S.ndim)
108+
91109
self.scaling = np.atleast_1d(scaling)
92110
self.tune = tune
93111
self.tune_interval = tune_interval

pymc3/tests/test_step.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from pymc3.sampling import assign_step_methods, sample
77
from pymc3.model import Model
88
from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
9-
Metropolis, Slice, CompoundStep,
9+
Metropolis, Slice, CompoundStep, NormalProposal,
1010
MultivariateNormalProposal, HamiltonianMC)
1111
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
1212

1313
from numpy.testing import assert_array_almost_equal
1414
import numpy as np
15+
import numpy.testing as npt
1516
from tqdm import tqdm
1617

1718

@@ -187,6 +188,29 @@ def test_step_categorical(self):
187188
yield self.check_stat, check, trace, step.__class__.__name__
188189

189190

191+
class TestMetropolisProposal(unittest.TestCase):
192+
def test_proposal_choice(self):
193+
_, model, _ = mv_simple()
194+
with model:
195+
s = np.ones(model.ndim)
196+
sampler = Metropolis(S=s)
197+
assert isinstance(sampler.proposal_dist, NormalProposal)
198+
s = np.diag(s)
199+
sampler = Metropolis(S=s)
200+
assert isinstance(sampler.proposal_dist, MultivariateNormalProposal)
201+
s[0, 0] = -s[0, 0]
202+
with self.assertRaises(np.linalg.LinAlgError):
203+
sampler = Metropolis(S=s)
204+
205+
def test_mv_proposal(self):
206+
np.random.seed(42)
207+
cov = np.random.randn(5, 5)
208+
cov = cov.dot(cov.T)
209+
prop = MultivariateNormalProposal(cov)
210+
samples = np.array([prop() for _ in range(10000)])
211+
npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2)
212+
213+
190214
class TestCompoundStep(unittest.TestCase):
191215
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)
192216

0 commit comments

Comments
 (0)