20
20
21
21
import aesara
22
22
import aesara .tensor as aet
23
- import arviz as az
24
23
import numpy as np
25
24
import numpy .testing as npt
26
25
import pytest
27
26
28
27
from aesara import shared
28
+ from arviz import InferenceData
29
+ from arviz import from_dict as az_from_dict
29
30
from scipy import stats
30
31
31
32
import pymc3 as pm
@@ -200,7 +201,7 @@ def test_return_inferencedata(self, monkeypatch):
200
201
201
202
# inferencedata with tuning
202
203
result = pm .sample (** kwargs , return_inferencedata = True , discard_tuned_samples = False )
203
- assert isinstance (result , az . InferenceData )
204
+ assert isinstance (result , InferenceData )
204
205
assert result .posterior .sizes ["draw" ] == 100
205
206
assert result .posterior .sizes ["chain" ] == 2
206
207
assert len (result ._groups_warmup ) > 0
@@ -215,7 +216,7 @@ def test_return_inferencedata(self, monkeypatch):
215
216
random_seed = - 1
216
217
)
217
218
assert "prior" in result
218
- assert isinstance (result , az . InferenceData )
219
+ assert isinstance (result , InferenceData )
219
220
assert result .posterior .sizes ["draw" ] == 100
220
221
assert result .posterior .sizes ["chain" ] == 2
221
222
assert len (result ._groups_warmup ) == 0
@@ -458,20 +459,26 @@ def test_normal_scalar(self):
458
459
ppc = pm .sample_posterior_predictive (trace , size = 5 , var_names = ["a" ])
459
460
assert ppc ["a" ].shape == (nchains * ndraws , 5 )
460
461
461
- @pytest .mark .xfail (reason = "Arviz not refactored for v4" )
462
462
def test_normal_scalar_idata (self ):
463
463
nchains = 2
464
464
ndraws = 500
465
465
with pm .Model () as model :
466
466
mu = pm .Normal ("mu" , 0.0 , 1.0 )
467
467
a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
468
468
trace = pm .sample (
469
- draws = ndraws , chains = nchains , return_inferencedata = True , discard_tuned_samples = False
469
+ draws = ndraws ,
470
+ chains = nchains ,
471
+ return_inferencedata = False ,
472
+ discard_tuned_samples = False ,
470
473
)
471
474
475
+ assert not isinstance (trace , InferenceData )
476
+
472
477
with model :
473
478
# test keep_size parameter and idata input
474
479
idata = pm .to_inference_data (trace )
480
+ assert isinstance (idata , InferenceData )
481
+
475
482
ppc = pm .sample_posterior_predictive (idata , keep_size = True )
476
483
assert ppc ["a" ].shape == (nchains , ndraws )
477
484
@@ -505,16 +512,19 @@ def test_normal_vector(self, caplog):
505
512
assert "a" in ppc
506
513
assert ppc ["a" ].shape == (10 , 4 , 2 )
507
514
508
- @pytest .mark .xfail (reason = "Arviz not refactored for v4" )
509
515
def test_normal_vector_idata (self , caplog ):
510
516
with pm .Model () as model :
511
517
mu = pm .Normal ("mu" , 0.0 , 1.0 )
512
518
a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = np .array ([0.5 , 0.2 ]))
513
519
trace = pm .sample (return_inferencedata = False )
514
520
521
+ assert not isinstance (trace , InferenceData )
522
+
515
523
with model :
516
524
# test keep_size parameter with inference data as input...
517
525
idata = pm .to_inference_data (trace )
526
+ assert isinstance (idata , InferenceData )
527
+
518
528
ppc = pm .sample_posterior_predictive (idata , keep_size = True )
519
529
assert ppc ["a" ].shape == (trace .nchains , len (trace ), 2 )
520
530
@@ -703,7 +713,7 @@ def test_potentials_warning(self):
703
713
p = pm .Potential ("p" , a + 1 )
704
714
obs = pm .Normal ("obs" , a , 1 , observed = 5 )
705
715
706
- trace = az . from_dict ({"a" : np .random .rand (10 )})
716
+ trace = az_from_dict ({"a" : np .random .rand (10 )})
707
717
with m :
708
718
with pytest .warns (UserWarning , match = warning_msg ):
709
719
pm .sample_posterior_predictive (trace , samples = 5 )
@@ -768,7 +778,7 @@ def test_potentials_warning(self):
768
778
p = pm .Potential ("p" , a + 1 )
769
779
obs = pm .Normal ("obs" , a , 1 , observed = 5 )
770
780
771
- trace = az . from_dict ({"a" : np .random .rand (10 )})
781
+ trace = az_from_dict ({"a" : np .random .rand (10 )})
772
782
with pytest .warns (UserWarning , match = warning_msg ):
773
783
pm .sample_posterior_predictive_w (samples = 5 , traces = [trace , trace ], models = [m , m ])
774
784
@@ -1031,17 +1041,17 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
1031
1041
with pmodel :
1032
1042
pp = pm .sample_posterior_predictive ([trace [15 ]], var_names = ["d" ])
1033
1043
1034
- @pytest .mark .xfail (reason = "Arviz not refactored for v4" )
1035
1044
def test_sample_from_xarray_prior (self , point_list_arg_bug_fixture ):
1036
1045
pmodel , trace = point_list_arg_bug_fixture
1037
1046
1038
1047
with pmodel :
1039
1048
prior = pm .sample_prior_predictive (samples = 20 )
1049
+
1040
1050
idat = pm .to_inference_data (trace , prior = prior )
1051
+
1041
1052
with pmodel :
1042
1053
pp = pm .sample_posterior_predictive (idat .prior , var_names = ["d" ])
1043
1054
1044
- @pytest .mark .xfail (reason = "Arviz not refactored for v4" )
1045
1055
def test_sample_from_xarray_posterior (self , point_list_arg_bug_fixture ):
1046
1056
pmodel , trace = point_list_arg_bug_fixture
1047
1057
idat = pm .to_inference_data (trace )
0 commit comments