Skip to content

Commit 6479ad4

Browse files
committed
Log sampled basic_RVs sample_*_predictive functions
1 parent afa7bbf commit 6479ad4

File tree

2 files changed

+165
-37
lines changed

2 files changed

+165
-37
lines changed

pymc/sampling.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
16221622
basic_rvs: Optional[List[Variable]] = None,
16231623
givens_dict: Optional[Dict[Variable, Any]] = None,
16241624
**kwargs,
1625-
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
1625+
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
16261626
"""Compile a function to draw samples, conditioned on the values of some variables.
16271627
16281628
The goal of this function is to walk the aesara computational graph from the list
@@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(
16351635
16361636
- Variables in the outputs list
16371637
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1638-
- Basic RVs that are not in the ``vars_in_trace`` list
1638+
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
16391639
- Variables that are keys in the ``givens_dict``
16401640
- Variables that have volatile inputs
16411641
1642-
Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
1643-
that are in the ``basic_rvs`` list.
1644-
16451642
Concretely, this function can be used to compile a function to sample from the
16461643
posterior predictive distribution of a model that has variables that are conditioned
16471644
on ``MutableData`` instances. The variables that depend on the mutable data will be
@@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
16701667
output of ``model.basic_RVs``) should have a reference to the variables that should
16711668
be considered as random variable instances. This includes variables that have
16721669
a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
1673-
Censored distributions. If ``None``, only pure random variables will be considered
1674-
as potential random variables.
1670+
Censored distributions.
16751671
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
16761672
A dictionary that maps tensor variables to the values that should be used to replace them
16771673
in the compiled function. The types of the key and value should match or an error will be
16781674
raised during compilation.
1675+
1676+
Returns
1677+
-------
1678+
function: Callable
1679+
Compiled forward sampling Aesara function
1680+
volatile_basic_rvs: Set of Variable
1681+
Set of all basic_rvs that were considered volatile and will be resampled when
1682+
the function is evaluated
16791683
"""
16801684
if givens_dict is None:
16811685
givens_dict = {}
@@ -1741,7 +1745,10 @@ def expand(node):
17411745
for node, value in givens_dict.items()
17421746
]
17431747

1744-
return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
1748+
return (
1749+
compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
1750+
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
1751+
)
17451752

17461753

17471754
def sample_posterior_predictive(
@@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
19001907
vars_ = model.observed_RVs + model.auto_deterministics
19011908

19021909
indices = np.arange(samples)
1903-
19041910
if progressbar:
19051911
indices = progress_bar(indices, total=samples, display=progressbar)
19061912

@@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
19231929
compile_kwargs.setdefault("allow_input_downcast", True)
19241930
compile_kwargs.setdefault("accept_inplace", True)
19251931

1926-
sampler_fn = point_wrapper(
1927-
compile_forward_sampling_function(
1928-
outputs=vars_to_sample,
1929-
vars_in_trace=vars_in_trace,
1930-
basic_rvs=model.basic_RVs,
1931-
givens_dict=None,
1932-
random_seed=random_seed,
1933-
**compile_kwargs,
1934-
)
1932+
_sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
1933+
outputs=vars_to_sample,
1934+
vars_in_trace=vars_in_trace,
1935+
basic_rvs=model.basic_RVs,
1936+
givens_dict=None,
1937+
random_seed=random_seed,
1938+
**compile_kwargs,
19351939
)
1936-
1940+
sampler_fn = point_wrapper(_sampler_fn)
1941+
# All model variables have a name, but mypy does not know this
1942+
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
19371943
ppc_trace_t = _DefaultTrace(samples)
19381944
try:
19391945
if isinstance(_trace, MultiTrace):
@@ -2242,7 +2248,7 @@ def sample_prior_predictive(
22422248
compile_kwargs.setdefault("allow_input_downcast", True)
22432249
compile_kwargs.setdefault("accept_inplace", True)
22442250

2245-
sampler_fn = compile_forward_sampling_function(
2251+
sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
22462252
vars_to_sample,
22472253
vars_in_trace=[],
22482254
basic_rvs=model.basic_RVs,
@@ -2251,6 +2257,8 @@ def sample_prior_predictive(
22512257
**compile_kwargs,
22522258
)
22532259

2260+
# All model variables have a name, but mypy does not know this
2261+
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
22542262
values = zip(*(sampler_fn() for i in range(samples)))
22552263

22562264
data = {k: np.stack(v) for k, v in zip(names, values)}

0 commit comments

Comments
 (0)