@@ -57,7 +57,7 @@
 )
 from aesara.tensor.rewriting.basic import topo_constant_folding
 from aesara.tensor.rewriting.shape import ShapeFeature
-from aesara.tensor.sharedvar import SharedVariable
+from aesara.tensor.sharedvar import SharedVariable, TensorSharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
@@ -535,55 +535,155 @@ def make_shared_replacements(point, vars, model):
 
 def join_nonshared_inputs(
     point: Dict[str, np.ndarray],
-    xs: List[TensorVariable],
-    vars: List[TensorVariable],
-    shared,
-    make_shared: bool = False,
-):
+    outputs: List[TensorVariable],
+    inputs: List[TensorVariable],
+    shared_inputs: Optional[Dict[TensorVariable, TensorSharedVariable]] = None,
+    make_inputs_shared: bool = False,
+) -> Tuple[List[TensorVariable], TensorVariable]:
     """
-    Takes a list of Aesara Variables and joins their non shared inputs into a single input.
+    Create new outputs and input TensorVariables where the non-shared inputs are joined
+    into a single raveled vector input.
 
     Parameters
     ----------
-    point: a sample point
-    xs: list of Aesara tensors
-    vars: list of variables to join
+    point : dict of {str : array_like}
+        Dictionary that maps each input variable name to a numerical value. The values
+        are only used to extract the shape of each input variable, in order to establish
+        a correct mapping between joined and original inputs. The shape of each variable
+        is assumed to be fixed.
+    outputs : list of TensorVariable
+        List of output TensorVariables whose non-shared inputs will be replaced
+        by a joined vector input.
+    inputs : list of TensorVariable
+        List of input TensorVariables which will be replaced by a joined vector input.
+    shared_inputs : dict of {TensorVariable : TensorSharedVariable}, optional
+        Dict mapping each input TensorVariable to the TensorSharedVariable that will
+        replace it in the subgraph.
+    make_inputs_shared : bool, default False
+        Whether to make the joined vector input a shared variable.
 
     Returns
     -------
-    tensors, inarray
-    tensors: list of same tensors but with inarray as input
-    inarray: vector of inputs
+    new_outputs : list of TensorVariable
+        List of new output TensorVariables, cloned from `outputs`, that now depend on
+        `joined_inputs` and the new shared variables as inputs.
+    joined_inputs : TensorVariable
+        Joined input vector TensorVariable for the `new_outputs`.
+
+    Examples
+    --------
+    Join the inputs of a simple Aesara graph.
+
+    .. code-block:: python
+
+        import aesara.tensor as at
+        import numpy as np
+
+        from pymc.aesaraf import join_nonshared_inputs
+
+        # Original non-shared inputs
+        x = at.scalar("x")
+        y = at.vector("y")
+        # Original output
+        out = x + y
+        print(out.eval({x: np.array(1), y: np.array([1, 2, 3])})) # [2, 3, 4]
+
+        # New output and inputs
+        [new_out], joined_inputs = join_nonshared_inputs(
+            point={ # Only shapes matter
+                "x": np.zeros(()),
+                "y": np.zeros(3),
+            },
+            outputs=[out],
+            inputs=[x, y],
+        )
+        print(new_out.eval({
+            joined_inputs: np.array([1, 1, 2, 3]),
+        })) # [2, 3, 4]
+
+    Join the input value variables of a model logp.
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            mu_pop = pm.Normal("mu_pop")
+            sigma_pop = pm.HalfNormal("sigma_pop")
+            mu = pm.Normal("mu", mu_pop, sigma_pop, shape=(3, ))
+
+            y = pm.Normal("y", mu, 1.0, observed=[0, 1, 2])
+
+        print(model.compile_logp()({
+            "mu_pop": 0,
+            "sigma_pop_log__": 1,
+            "mu": [0, 1, 2],
+        })) # -12.691227342634292
+
+        initial_point = model.initial_point()
+        inputs = model.value_vars
+
+        [logp], joined_inputs = join_nonshared_inputs(
+            point=initial_point,
+            outputs=[model.logp()],
+            inputs=inputs,
+        )
+
+        print(logp.eval({
+            joined_inputs: [0, 1, 0, 1, 2],
+        })) # -12.691227342634292
+
+    Same as above but with the `mu_pop` value variable being shared.
+
+    .. code-block:: python
+
+        from aesara import shared
+
+        mu_pop_input, *other_inputs = inputs
+        shared_mu_pop_input = shared(0.0)
+
+        [logp], other_joined_inputs = join_nonshared_inputs(
+            point=initial_point,
+            outputs=[model.logp()],
+            inputs=other_inputs,
+            shared_inputs={
+                mu_pop_input: shared_mu_pop_input
+            },
+        )
+
+        print(logp.eval({
+            other_joined_inputs: [1, 0, 1, 2],
+        })) # -12.691227342634292
     """
-    if not vars:
-        raise ValueError("Empty list of variables.")
+    if not inputs:
+        raise ValueError("Empty list of input variables.")
 
-    joined = at.concatenate([var.ravel() for var in vars])
+    raveled_inputs = at.concatenate([var.ravel() for var in inputs])
 
-    if not make_shared:
-        tensor_type = joined.type
-        inarray = tensor_type("inarray")
+    if not make_inputs_shared:
+        tensor_type = raveled_inputs.type
+        joined_inputs = tensor_type("joined_inputs")
     else:
-        if point is None:
-            raise ValueError("A point is required when `make_shared` is True")
-        joined_values = np.concatenate([point[var.name].ravel() for var in vars])
-        inarray = aesara.shared(joined_values, "inarray")
+        joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
+        joined_inputs = aesara.shared(joined_values, "joined_inputs")
 
     if aesara.config.compute_test_value != "off":
-        inarray.tag.test_value = joined.tag.test_value
+        joined_inputs.tag.test_value = raveled_inputs.tag.test_value
 
-    replace = {}
+    replace: Dict[TensorVariable, TensorVariable] = {}
     last_idx = 0
-    for var in vars:
+    for var in inputs:
         shape = point[var.name].shape
         arr_len = np.prod(shape, dtype=int)
-        replace[var] = inarray[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
+        replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype)
         last_idx += arr_len
 
-    replace.update(shared)
+    if shared_inputs is not None:
+        replace.update(shared_inputs)
 
-    xs_special = [aesara.clone_replace(x, replace, rebuild_strict=False) for x in xs]
-    return xs_special, inarray
+    new_outputs = [
+        aesara.clone_replace(output, replace, rebuild_strict=False) for output in outputs
+    ]
+    return new_outputs, joined_inputs
 
 
 class PointFunc:
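A minimal before/after sketch of a call site affected by this rename (the names `a`, `b`, `cost` and the numeric values are illustrative assumptions, not taken from this diff): the old positional call `join_nonshared_inputs(point, xs, vars, shared)` becomes a keyword call with the new names, and `shared_inputs` can now be omitted entirely.

.. code-block:: python

    import aesara.tensor as at
    import numpy as np

    from pymc.aesaraf import join_nonshared_inputs

    # Hypothetical graph with two non-shared inputs.
    a = at.vector("a")
    b = at.vector("b")
    cost = (a * b).sum()

    # Only the shapes of these values are used to split the joined vector.
    point = {"a": np.zeros(2), "b": np.zeros(2)}

    # Old API (pre-change): join_nonshared_inputs(point, [cost], [a, b], {})
    # New API: renamed keyword arguments; `shared_inputs` defaults to None.
    [new_cost], joined_inputs = join_nonshared_inputs(
        point=point,
        outputs=[cost],
        inputs=[a, b],
    )

    # The joined vector is a=[1, 2] followed by b=[3, 4]: 1*3 + 2*4 = 11.
    values = np.array([1.0, 2.0, 3.0, 4.0], dtype=joined_inputs.dtype)
    print(new_cost.eval({joined_inputs: values}))  # 11.0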