From a0e36d0df29b944ddc3327b113d69c348a39b0c8 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 14 May 2021 11:32:45 +0200 Subject: [PATCH 1/3] Fix BinaryMetropolis astep --- pymc3/step_methods/metropolis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc3/step_methods/metropolis.py b/pymc3/step_methods/metropolis.py index 6e226f153b..c9cb58470f 100644 --- a/pymc3/step_methods/metropolis.py +++ b/pymc3/step_methods/metropolis.py @@ -329,6 +329,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): def astep(self, q0: RaveledVars, logp) -> Tuple[RaveledVars, List[Dict[str, Any]]]: + logp_q0 = logp(q0) point_map_info = q0.point_map_info q0 = q0.data @@ -340,8 +341,9 @@ def astep(self, q0: RaveledVars, logp) -> Tuple[RaveledVars, List[Dict[str, Any] # Locations where switches occur, according to p_jump switch_locs = rand_array < p_jump q[switch_locs] = True - q[switch_locs] + logp_q = logp(RaveledVars(q, point_map_info)) - accept = logp(q) - logp(q0) + accept = logp_q - logp_q0 q_new, accepted = metrop_select(accept, q, q0) self.accepted += accepted From 9b963eb8ba55048577e163f4cb4f179d51fe9417 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 14 May 2021 11:33:07 +0200 Subject: [PATCH 2/3] Fix _check_start_shape --- pymc3/sampling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 3a70cc38db..475ea30a96 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -684,6 +684,11 @@ def sample( def _check_start_shape(model, start): if not isinstance(start, dict): raise TypeError("start argument must be a dict or an array-like of dicts") + + # Filter "non-input" variables + initial_point = model.initial_point + start = {k: v for k, v in start.items() if k in initial_point} + e = "" for var in model.basic_RVs: var_shape = model.fastfn(var.shape)(start) From 19dcc45ca59498f0935e0dd9feefb67858257485 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 14 May 2021 11:33:38 +0200 Subject: [PATCH 3/3] Update failing tests --- pymc3/tests/test_examples.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index ad54236543..66c7b2fbe1 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -194,7 +194,7 @@ def build_disaster_model(masked=False): @pytest.mark.xfail( - reason="_check_start_shape fails with start dictionary" + reason="Arviz summary fails" # condition=(aesara.config.floatX == "float32"), reason="Fails on float32" ) class TestDisasterModel(SeededTest): @@ -222,7 +222,6 @@ def test_disaster_model_missing(self): az.summary(tr) -@pytest.mark.xfail(reason="_check_start_shape fails with start dictionary") class TestLatentOccupancy(SeededTest): """ From the PyMC example list @@ -278,7 +277,7 @@ def test_run(self): "theta": np.array(5, dtype="f"), } step_one = pm.Metropolis([model["theta_interval__"], model["psi_logodds__"]]) - step_two = pm.BinaryMetropolis([model.z]) + step_two = pm.BinaryMetropolis([model.rvs_to_values[model["z"]]]) pm.sample(50, step=[step_one, step_two], start=start, chains=1)