Skip to content

Log basic_RVs sampled in sample_*_predictive functions #6142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
basic_rvs: Optional[List[Variable]] = None,
givens_dict: Optional[Dict[Variable, Any]] = None,
**kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
"""Compile a function to draw samples, conditioned on the values of some variables.

The goal of this function is to walk the aesara computational graph from the list
Expand All @@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(

- Variables in the outputs list
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
- Basic RVs that are not in the ``vars_in_trace`` list
- Variables that are in the ``basic_rvs`` list but not in the ``vars_in_trace`` list
- Variables that are keys in the ``givens_dict``
- Variables that have volatile inputs

Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
that are in the ``basic_rvs`` list.

Concretely, this function can be used to compile a function to sample from the
posterior predictive distribution of a model that has variables that are conditioned
on ``MutableData`` instances. The variables that depend on the mutable data will be
Expand Down Expand Up @@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
output of ``model.basic_RVs``) should have a reference to the variables that should
be considered as random variable instances. This includes variables that have
a ``RandomVariable`` owner op, but also impure random variables like Mixtures, or
Censored distributions. If ``None``, only pure random variables will be considered
as potential random variables.
Censored distributions.
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
A dictionary that maps tensor variables to the values that should be used to replace them
in the compiled function. The types of the key and value should match or an error will be
raised during compilation.

Returns
-------
function : Callable
Compiled forward sampling Aesara function
volatile_basic_rvs : Set[Variable]
Set of all ``basic_rvs`` that were considered volatile and will be resampled when
the function is evaluated
"""
if givens_dict is None:
givens_dict = {}
Expand Down Expand Up @@ -1741,7 +1745,10 @@ def expand(node):
for node, value in givens_dict.items()
]

return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
return (
compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
)


def sample_posterior_predictive(
Expand Down Expand Up @@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
vars_ = model.observed_RVs + model.auto_deterministics

indices = np.arange(samples)

if progressbar:
indices = progress_bar(indices, total=samples, display=progressbar)

Expand All @@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = point_wrapper(
compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
random_seed=random_seed,
**compile_kwargs,
)
_sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
random_seed=random_seed,
**compile_kwargs,
)

sampler_fn = point_wrapper(_sampler_fn)
# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
if isinstance(_trace, MultiTrace):
Expand Down Expand Up @@ -2242,7 +2248,7 @@ def sample_prior_predictive(
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = compile_forward_sampling_function(
sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
vars_to_sample,
vars_in_trace=[],
basic_rvs=model.basic_RVs,
Expand All @@ -2251,6 +2257,8 @@ def sample_prior_predictive(
**compile_kwargs,
)

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
values = zip(*(sampler_fn() for i in range(samples)))

data = {k: np.stack(v) for k, v in zip(names, values)}
Expand Down
Loading