Skip to content

Commit b0d1066

Browse files
authored
Improve join_nonshared_inputs documentation (#6216)
1 parent 23c4834 commit b0d1066

File tree

3 files changed

+147
-36
lines changed

3 files changed

+147
-36
lines changed

pymc/aesaraf.py

Lines changed: 130 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
)
5858
from aesara.tensor.rewriting.basic import topo_constant_folding
5959
from aesara.tensor.rewriting.shape import ShapeFeature
60-
from aesara.tensor.sharedvar import SharedVariable
60+
from aesara.tensor.sharedvar import SharedVariable, TensorSharedVariable
6161
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
6262
from aesara.tensor.var import TensorConstant, TensorVariable
6363

@@ -535,55 +535,155 @@ def make_shared_replacements(point, vars, model):
535535

536536
def join_nonshared_inputs(
537537
point: Dict[str, np.ndarray],
538-
xs: List[TensorVariable],
539-
vars: List[TensorVariable],
540-
shared,
541-
make_shared: bool = False,
542-
):
538+
outputs: List[TensorVariable],
539+
inputs: List[TensorVariable],
540+
shared_inputs: Optional[Dict[TensorVariable, TensorSharedVariable]] = None,
541+
make_inputs_shared: bool = False,
542+
) -> Tuple[List[TensorVariable], TensorVariable]:
543543
"""
544-
Takes a list of Aesara Variables and joins their non shared inputs into a single input.
544+
Create new outputs and input TensorVariables where the non-shared inputs are joined
545+
into a single raveled vector input.
545546
546547
Parameters
547548
----------
548-
point: a sample point
549-
xs: list of Aesara tensors
550-
vars: list of variables to join
549+
point : dict of {str : array_like}
550+
Dictionary that maps each input variable name to a numerical value. The values
551+
are used to extract the shape of each input variable to establish a correct
552+
mapping between joined and original inputs. The shape of each variable is
553+
assumed to be fixed.
554+
outputs : list of TensorVariable
555+
List of output TensorVariables whose non-shared inputs will be replaced
556+
by a joined vector input.
557+
inputs : list of TensorVariable
558+
List of input TensorVariables which will be replaced by a joined vector input.
559+
shared_inputs : dict of {TensorVariable : TensorSharedVariable}, optional
560+
Dict mapping each TensorVariable to the TensorSharedVariable that replaces it
561+
during subgraph replacement.
562+
make_inputs_shared : bool, default False
563+
Whether to make the joined vector input a shared variable.
551564
552565
Returns
553566
-------
554-
tensors, inarray
555-
tensors: list of same tensors but with inarray as input
556-
inarray: vector of inputs
567+
new_outputs : list of TensorVariable
568+
List of new TensorVariables, one for each entry of `outputs`, that depend on `joined_inputs` and the new shared variables as inputs.
569+
joined_inputs : TensorVariable
570+
Joined input vector TensorVariable for the `new_outputs`.
571+
572+
Examples
573+
--------
574+
Join the inputs of a simple Aesara graph.
575+
576+
.. code-block:: python
577+
578+
import aesara.tensor as at
579+
import numpy as np
580+
581+
from pymc.aesaraf import join_nonshared_inputs
582+
583+
# Original non-shared inputs
584+
x = at.scalar("x")
585+
y = at.vector("y")
586+
# Original output
587+
out = x + y
588+
print(out.eval({x: np.array(1), y: np.array([1, 2, 3])})) # [2, 3, 4]
589+
590+
# New output and inputs
591+
[new_out], joined_inputs = join_nonshared_inputs(
592+
point={ # Only shapes matter
593+
"x": np.zeros(()),
594+
"y": np.zeros(3),
595+
},
596+
outputs=[out],
597+
inputs=[x, y],
598+
)
599+
print(new_out.eval({
600+
joined_inputs: np.array([1, 1, 2, 3]),
601+
})) # [2, 3, 4]
602+
603+
Join the input value variables of a model logp.
604+
605+
.. code-block:: python
606+
607+
import pymc as pm
608+
609+
with pm.Model() as model:
610+
mu_pop = pm.Normal("mu_pop")
611+
sigma_pop = pm.HalfNormal("sigma_pop")
612+
mu = pm.Normal("mu", mu_pop, sigma_pop, shape=(3, ))
613+
614+
y = pm.Normal("y", mu, 1.0, observed=[0, 1, 2])
615+
616+
print(model.compile_logp()({
617+
"mu_pop": 0,
618+
"sigma_pop_log__": 1,
619+
"mu": [0, 1, 2],
620+
})) # -12.691227342634292
621+
622+
initial_point = model.initial_point()
623+
inputs = model.value_vars
624+
625+
[logp], joined_inputs = join_nonshared_inputs(
626+
point=initial_point,
627+
outputs=[model.logp()],
628+
inputs=inputs,
629+
)
630+
631+
print(logp.eval({
632+
joined_inputs: [0, 1, 0, 1, 2],
633+
})) # -12.691227342634292
634+
635+
Same as above but with the `mu_pop` value variable being shared.
636+
637+
.. code-block:: python
638+
639+
from aesara import shared
640+
641+
mu_pop_input, *other_inputs = inputs
642+
shared_mu_pop_input = shared(0.0)
643+
644+
[logp], other_joined_inputs = join_nonshared_inputs(
645+
point=initial_point,
646+
outputs=[model.logp()],
647+
inputs=other_inputs,
648+
shared_inputs={
649+
mu_pop_input: shared_mu_pop_input
650+
},
651+
)
652+
653+
print(logp.eval({
654+
other_joined_inputs: [1, 0, 1, 2],
655+
})) # -12.691227342634292
557656
"""
558-
if not vars:
559-
raise ValueError("Empty list of variables.")
657+
if not inputs:
658+
raise ValueError("Empty list of input variables.")
560659

561-
joined = at.concatenate([var.ravel() for var in vars])
660+
raveled_inputs = at.concatenate([var.ravel() for var in inputs])
562661

563-
if not make_shared:
564-
tensor_type = joined.type
565-
inarray = tensor_type("inarray")
662+
if not make_inputs_shared:
663+
tensor_type = raveled_inputs.type
664+
joined_inputs = tensor_type("joined_inputs")
566665
else:
567-
if point is None:
568-
raise ValueError("A point is required when `make_shared` is True")
569-
joined_values = np.concatenate([point[var.name].ravel() for var in vars])
570-
inarray = aesara.shared(joined_values, "inarray")
666+
joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
667+
joined_inputs = aesara.shared(joined_values, "joined_inputs")
571668

572669
if aesara.config.compute_test_value != "off":
573-
inarray.tag.test_value = joined.tag.test_value
670+
joined_inputs.tag.test_value = raveled_inputs.tag.test_value
574671

575-
replace = {}
672+
replace: Dict[TensorVariable, TensorVariable] = {}
576673
last_idx = 0
577-
for var in vars:
674+
for var in inputs:
578675
shape = point[var.name].shape
579676
arr_len = np.prod(shape, dtype=int)
580-
replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
677+
replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
581678
last_idx += arr_len
582679

583-
replace.update(shared)
680+
if shared_inputs is not None:
681+
replace.update(shared_inputs)
584682

585-
xs_special = [aesara.clone_replace(x, replace, rebuild_strict=False) for x in xs]
586-
return xs_special, inarray
683+
new_outputs = [
684+
aesara.clone_replace(output, replace, rebuild_strict=False) for output in outputs
685+
]
686+
return new_outputs, joined_inputs
587687

588688

589689
class PointFunc:

pymc/smc/kernels.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,11 @@ def _logp_forw(point, out_vars, in_vars, shared):
579579
580580
Parameters
581581
----------
582-
out_vars: List
582+
out_vars : list
583583
containing :class:`pymc.Distribution` for the output variables
584-
in_vars: List
584+
in_vars : list
585585
containing :class:`pymc.Distribution` for the input variables
586-
shared: List
586+
shared : list
587587
containing :class:`aesara.tensor.Tensor` for the shared data that is depended upon
588588
"""
589589

@@ -602,7 +602,9 @@ def _logp_forw(point, out_vars, in_vars, shared):
602602
out_vars = clone_replace(out_vars, replace_int_input, rebuild_strict=False)
603603
in_vars = new_in_vars
604604

605-
out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared)
605+
out_list, inarray0 = join_nonshared_inputs(
606+
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
607+
)
606608
f = compile_pymc([inarray0], out_list[0])
607609
f.trust_input = True
608610
return f

pymc/step_methods/metropolis.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
# limitations under the License.
1414
from typing import Any, Callable, Dict, List, Optional, Tuple
1515

16+
import aesara
1617
import numpy as np
1718
import numpy.random as nr
1819
import scipy.linalg
1920
import scipy.special
2021

22+
from aesara import tensor as at
2123
from aesara.graph.fg import MissingInputError
2224
from aesara.tensor.random.basic import BernoulliRV, CategoricalRV
2325

@@ -1052,8 +1054,15 @@ def sample_except(limit, excluded):
10521054
return candidate
10531055

10541056

1055-
def delta_logp(point, logp, vars, shared):
1056-
[logp0], inarray0 = join_nonshared_inputs(point, [logp], vars, shared)
1057+
def delta_logp(
1058+
point: Dict[str, np.ndarray],
1059+
logp: at.TensorVariable,
1060+
vars: List[at.TensorVariable],
1061+
shared: Dict[at.TensorVariable, at.sharedvar.TensorSharedVariable],
1062+
) -> aesara.compile.Function:
1063+
[logp0], inarray0 = join_nonshared_inputs(
1064+
point=point, outputs=[logp], inputs=vars, shared_inputs=shared
1065+
)
10571066

10581067
tensor_type = inarray0.type
10591068
inarray1 = tensor_type("inarray1")

0 commit comments

Comments
 (0)