@@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
1622
1622
basic_rvs : Optional [List [Variable ]] = None ,
1623
1623
givens_dict : Optional [Dict [Variable , Any ]] = None ,
1624
1624
** kwargs ,
1625
- ) -> Callable [..., Union [np .ndarray , List [np .ndarray ]]]:
1625
+ ) -> Tuple [ Callable [..., Union [np .ndarray , List [np .ndarray ]]], Set [ Variable ]]:
1626
1626
"""Compile a function to draw samples, conditioned on the values of some variables.
1627
1627
1628
1628
The goal of this function is to walk the aesara computational graph from the list
@@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(
1635
1635
1636
1636
- Variables in the outputs list
1637
1637
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1638
- - Basic RVs that are not in the ``vars_in_trace`` list
1638
+ - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
1639
1639
- Variables that are keys in the ``givens_dict``
1640
1640
- Variables that have volatile inputs
1641
1641
1642
- Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
1643
- that are in the ``basic_rvs`` list.
1644
-
1645
1642
Concretely, this function can be used to compile a function to sample from the
1646
1643
posterior predictive distribution of a model that has variables that are conditioned
1647
1644
on ``MutableData`` instances. The variables that depend on the mutable data will be
@@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
1670
1667
output of ``model.basic_RVs``) should have a reference to the variables that should
1671
1668
be considered as random variable instances. This includes variables that have
1672
1669
a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
1673
- Censored distributions. If ``None``, only pure random variables will be considered
1674
- as potential random variables.
1670
+ Censored distributions.
1675
1671
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
1676
1672
A dictionary that maps tensor variables to the values that should be used to replace them
1677
1673
in the compiled function. The types of the key and value should match or an error will be
1678
1674
raised during compilation.
1675
+
1676
+ Returns
1677
+ -------
1678
+ function: Callable
1679
+ Compiled forward sampling Aesara function
1680
+ volatile_basic_rvs: Set of Variable
1681
+ Set of all basic_rvs that were considered volatile and will be resampled when
1682
+ the function is evaluated
1679
1683
"""
1680
1684
if givens_dict is None :
1681
1685
givens_dict = {}
@@ -1741,7 +1745,10 @@ def expand(node):
1741
1745
for node , value in givens_dict .items ()
1742
1746
]
1743
1747
1744
- return compile_pymc (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs )
1748
+ return (
1749
+ compile_pymc (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs ),
1750
+ set (basic_rvs ) & (volatile_nodes - set (givens_dict )), # Basic RVs that will be resampled
1751
+ )
1745
1752
1746
1753
1747
1754
def sample_posterior_predictive (
@@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
1900
1907
vars_ = model .observed_RVs + model .auto_deterministics
1901
1908
1902
1909
indices = np .arange (samples )
1903
-
1904
1910
if progressbar :
1905
1911
indices = progress_bar (indices , total = samples , display = progressbar )
1906
1912
@@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
1923
1929
compile_kwargs .setdefault ("allow_input_downcast" , True )
1924
1930
compile_kwargs .setdefault ("accept_inplace" , True )
1925
1931
1926
- sampler_fn = point_wrapper (
1927
- compile_forward_sampling_function (
1928
- outputs = vars_to_sample ,
1929
- vars_in_trace = vars_in_trace ,
1930
- basic_rvs = model .basic_RVs ,
1931
- givens_dict = None ,
1932
- random_seed = random_seed ,
1933
- ** compile_kwargs ,
1934
- )
1932
+ _sampler_fn , volatile_basic_rvs = compile_forward_sampling_function (
1933
+ outputs = vars_to_sample ,
1934
+ vars_in_trace = vars_in_trace ,
1935
+ basic_rvs = model .basic_RVs ,
1936
+ givens_dict = None ,
1937
+ random_seed = random_seed ,
1938
+ ** compile_kwargs ,
1935
1939
)
1936
-
1940
+ sampler_fn = point_wrapper (_sampler_fn )
1941
+ # All model variables have a name, but mypy does not know this
1942
+ _log .info (f"Sampling: { list (sorted (volatile_basic_rvs , key = lambda var : var .name ))} " ) # type: ignore
1937
1943
ppc_trace_t = _DefaultTrace (samples )
1938
1944
try :
1939
1945
if isinstance (_trace , MultiTrace ):
@@ -2242,7 +2248,7 @@ def sample_prior_predictive(
2242
2248
compile_kwargs .setdefault ("allow_input_downcast" , True )
2243
2249
compile_kwargs .setdefault ("accept_inplace" , True )
2244
2250
2245
- sampler_fn = compile_forward_sampling_function (
2251
+ sampler_fn , volatile_basic_rvs = compile_forward_sampling_function (
2246
2252
vars_to_sample ,
2247
2253
vars_in_trace = [],
2248
2254
basic_rvs = model .basic_RVs ,
@@ -2251,6 +2257,8 @@ def sample_prior_predictive(
2251
2257
** compile_kwargs ,
2252
2258
)
2253
2259
2260
+ # All model variables have a name, but mypy does not know this
2261
+ _log .info (f"Sampling: { list (sorted (volatile_basic_rvs , key = lambda var : var .name ))} " ) # type: ignore
2254
2262
values = zip (* (sampler_fn () for i in range (samples )))
2255
2263
2256
2264
data = {k : np .stack (v ) for k , v in zip (names , values )}
0 commit comments