Skip to content

Commit 13dfeb2

Browse files
dfmmichaelosthege
andauthored
Adding option to include transformed variables in InferenceData (#6232)
* Adding option to include transformed variables in InferenceData Co-authored-by: Michael Osthege <[email protected]>
1 parent b5b63f5 commit 13dfeb2

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

pymc/backends/arviz.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,11 @@ def __init__(
164164
dims: Optional[DimSpec] = None,
165165
model=None,
166166
save_warmup: Optional[bool] = None,
167+
include_transformed: bool = False,
167168
):
168169

169170
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
171+
self.include_transformed = include_transformed
170172
self.trace = trace
171173

172174
# this permits us to get the model from command-line argument or from with model:
@@ -311,7 +313,9 @@ def _extract_log_likelihood(self, trace):
311313
@requires("trace")
312314
def posterior_to_xarray(self):
313315
"""Convert the posterior to an xarray dataset."""
314-
var_names = get_default_varnames(self.trace.varnames, include_transformed=False)
316+
var_names = get_default_varnames(
317+
self.trace.varnames, include_transformed=self.include_transformed
318+
)
315319
data = {}
316320
data_warmup = {}
317321
for var_name in var_names:
@@ -539,6 +543,7 @@ def to_inference_data(
539543
dims: Optional[DimSpec] = None,
540544
model: Optional["Model"] = None,
541545
save_warmup: Optional[bool] = None,
546+
include_transformed: bool = False,
542547
) -> InferenceData:
543548
"""Convert pymc data into an InferenceData object.
544549
@@ -571,6 +576,9 @@ def to_inference_data(
571576
save_warmup : bool, optional
572577
Save warmup iterations InferenceData object. If not defined, use default
573578
defined by the rcParams.
579+
include_transformed : bool, optional
580+
Save the transformed parameters in the InferenceData object. By default, these are
581+
not saved.
574582
575583
Returns
576584
-------
@@ -588,6 +596,7 @@ def to_inference_data(
588596
dims=dims,
589597
model=model,
590598
save_warmup=save_warmup,
599+
include_transformed=include_transformed,
591600
).to_inference_data()
592601

593602

pymc/tests/backends/test_arviz.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def test_autodetect_coords_from_model(self, use_context):
279279
np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"])
280280
np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"])
281281

282-
def test_ovewrite_model_coords_dims(self):
282+
def test_overwrite_model_coords_dims(self):
283283
"""Check coords and dims from model object can be partially overwritten."""
284284
dim1 = ["a", "b"]
285285
new_dim1 = ["c", "d"]
@@ -617,6 +617,23 @@ def test_variable_dimension_name_collision(self):
617617
var = at.as_tensor([1, 2, 3])
618618
pmodel.register_rv(var, name="time", dims=("time",))
619619

620+
def test_include_transformed(self):
621+
with pm.Model():
622+
pm.Uniform("p", 0, 1)
623+
624+
# First check that the default is to exclude the transformed variables
625+
sample_kwargs = dict(tune=5, draws=7, chains=2, cores=1)
626+
inference_data = pm.sample(**sample_kwargs, step=pm.Metropolis())
627+
assert "p_interval__" not in inference_data.posterior
628+
629+
# Now check that they are included when requested
630+
inference_data = pm.sample(
631+
**sample_kwargs,
632+
step=pm.Metropolis(),
633+
idata_kwargs={"include_transformed": True},
634+
)
635+
assert "p_interval__" in inference_data.posterior
636+
620637

621638
class TestPyMCWarmupHandling:
622639
@pytest.mark.parametrize("save_warmup", [False, True])

0 commit comments

Comments
 (0)