Skip to content

Commit 1a604c3

Browse files
Re-enable Arviz tests in pymc3.tests.test_sampling
1 parent e01a473 commit 1a604c3

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

pymc3/tests/test_sampling.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020

2121
import aesara
2222
import aesara.tensor as aet
23-
import arviz as az
2423
import numpy as np
2524
import numpy.testing as npt
2625
import pytest
2726

2827
from aesara import shared
28+
from arviz import InferenceData
29+
from arviz import from_dict as az_from_dict
2930
from scipy import stats
3031

3132
import pymc3 as pm
@@ -200,7 +201,7 @@ def test_return_inferencedata(self, monkeypatch):
200201

201202
# inferencedata with tuning
202203
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
203-
assert isinstance(result, az.InferenceData)
204+
assert isinstance(result, InferenceData)
204205
assert result.posterior.sizes["draw"] == 100
205206
assert result.posterior.sizes["chain"] == 2
206207
assert len(result._groups_warmup) > 0
@@ -215,7 +216,7 @@ def test_return_inferencedata(self, monkeypatch):
215216
random_seed=-1
216217
)
217218
assert "prior" in result
218-
assert isinstance(result, az.InferenceData)
219+
assert isinstance(result, InferenceData)
219220
assert result.posterior.sizes["draw"] == 100
220221
assert result.posterior.sizes["chain"] == 2
221222
assert len(result._groups_warmup) == 0
@@ -458,20 +459,26 @@ def test_normal_scalar(self):
458459
ppc = pm.sample_posterior_predictive(trace, size=5, var_names=["a"])
459460
assert ppc["a"].shape == (nchains * ndraws, 5)
460461

461-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
462462
def test_normal_scalar_idata(self):
463463
nchains = 2
464464
ndraws = 500
465465
with pm.Model() as model:
466466
mu = pm.Normal("mu", 0.0, 1.0)
467467
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
468468
trace = pm.sample(
469-
draws=ndraws, chains=nchains, return_inferencedata=True, discard_tuned_samples=False
469+
draws=ndraws,
470+
chains=nchains,
471+
return_inferencedata=False,
472+
discard_tuned_samples=False,
470473
)
471474

475+
assert not isinstance(trace, InferenceData)
476+
472477
with model:
473478
# test keep_size parameter and idata input
474479
idata = pm.to_inference_data(trace)
480+
assert isinstance(idata, InferenceData)
481+
475482
ppc = pm.sample_posterior_predictive(idata, keep_size=True)
476483
assert ppc["a"].shape == (nchains, ndraws)
477484

@@ -505,16 +512,19 @@ def test_normal_vector(self, caplog):
505512
assert "a" in ppc
506513
assert ppc["a"].shape == (10, 4, 2)
507514

508-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
509515
def test_normal_vector_idata(self, caplog):
510516
with pm.Model() as model:
511517
mu = pm.Normal("mu", 0.0, 1.0)
512518
a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))
513519
trace = pm.sample(return_inferencedata=False)
514520

521+
assert not isinstance(trace, InferenceData)
522+
515523
with model:
516524
# test keep_size parameter with inference data as input...
517525
idata = pm.to_inference_data(trace)
526+
assert isinstance(idata, InferenceData)
527+
518528
ppc = pm.sample_posterior_predictive(idata, keep_size=True)
519529
assert ppc["a"].shape == (trace.nchains, len(trace), 2)
520530

@@ -703,7 +713,7 @@ def test_potentials_warning(self):
703713
p = pm.Potential("p", a + 1)
704714
obs = pm.Normal("obs", a, 1, observed=5)
705715

706-
trace = az.from_dict({"a": np.random.rand(10)})
716+
trace = az_from_dict({"a": np.random.rand(10)})
707717
with m:
708718
with pytest.warns(UserWarning, match=warning_msg):
709719
pm.sample_posterior_predictive(trace, samples=5)
@@ -768,7 +778,7 @@ def test_potentials_warning(self):
768778
p = pm.Potential("p", a + 1)
769779
obs = pm.Normal("obs", a, 1, observed=5)
770780

771-
trace = az.from_dict({"a": np.random.rand(10)})
781+
trace = az_from_dict({"a": np.random.rand(10)})
772782
with pytest.warns(UserWarning, match=warning_msg):
773783
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])
774784

@@ -1031,17 +1041,17 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
10311041
with pmodel:
10321042
pp = pm.sample_posterior_predictive([trace[15]], var_names=["d"])
10331043

1034-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
10351044
def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
10361045
pmodel, trace = point_list_arg_bug_fixture
10371046

10381047
with pmodel:
10391048
prior = pm.sample_prior_predictive(samples=20)
1049+
10401050
idat = pm.to_inference_data(trace, prior=prior)
1051+
10411052
with pmodel:
10421053
pp = pm.sample_posterior_predictive(idat.prior, var_names=["d"])
10431054

1044-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
10451055
def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
10461056
pmodel, trace = point_list_arg_bug_fixture
10471057
idat = pm.to_inference_data(trace)

0 commit comments

Comments
 (0)