From 4c720cb9dedb523eeaeafd81da21d73c0e1f0889 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 8 Apr 2022 17:28:52 +0200 Subject: [PATCH 1/3] Do not use tag in Mixture classmethods --- pymc/distributions/mixture.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 5fe1b8a374..909ac6a1c1 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -294,11 +294,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None): # inside OpFromGraph and PyMC will never find it otherwise mix_indexes_rng.default_update = mix_out.owner.outputs[0] - # Reference nodes to facilitate identification in other classmethods - mix_out.tag.weights = weights - mix_out.tag.components = components - mix_out.tag.choices_rng = mix_indexes_rng - # Component RVs terms are accounted by the Mixture logprob, so they can be # safely ignore by Aeppl (this tag prevents UserWarning) for component in components: @@ -337,15 +332,14 @@ def ndim_supp(cls, weights, *components): @classmethod def change_size(cls, rv, new_size, expand=False): - weights = rv.tag.weights - components = rv.tag.components - rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng] + mix_indexes_rng, weights, *components = rv.owner.inputs + rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng] if expand: - component = rv.tag.components[0] + component = components[0] # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0` size_dims = component.ndim - component.owner.op.ndim_supp - if len(rv.tag.components) == 1: + if len(components) == 1: # If we have a single component, new size should ignore the mixture axis # dimension, as that is not touched by `_resize_components` size_dims -= 1 @@ -362,7 +356,7 @@ def graph_rvs(cls, rv): # mix_indexes_ RV in its inner graph. 
We want super().dist() to generate # (components + 1) rngs for us, and it will do so based on how many elements # we return here - return (*rv.tag.components, rv) + return (*rv.owner.inputs[2:], rv) @_get_measurable_outputs.register(MarginalMixtureRV) From 042531212df14a1d7a51b58d7fea1273ea20ee99 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 8 Apr 2022 18:21:40 +0200 Subject: [PATCH 2/3] Create dispatched ndim_supp_dist --- pymc/distributions/censored.py | 13 ++++++++----- pymc/distributions/distribution.py | 18 +++++++++++++----- pymc/distributions/mixture.py | 28 +++++++++++++++------------- pymc/distributions/multivariate.py | 5 +++-- pymc/distributions/shape_utils.py | 25 ++++++++++++++++++++++++- pymc/distributions/timeseries.py | 4 ++-- pymc/model.py | 7 +++++++ 7 files changed, 72 insertions(+), 28 deletions(-) diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 60edd7e227..084f9685c5 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -19,6 +19,7 @@ from aesara.tensor.random.op import RandomVariable from pymc.distributions.distribution import SymbolicDistribution, _moment +from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist from pymc.util import check_dist_not_registered @@ -65,7 +66,7 @@ def dist(cls, dist, lower, upper, **kwargs): raise ValueError( f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" ) - if dist.owner.op.ndim_supp > 0: + if ndim_supp_dist(dist) > 0: raise NotImplementedError( "Censoring of multivariate distributions has not been implemented yet" ) @@ -95,10 +96,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): return rv_out - @classmethod - def ndim_supp(cls, *dist_params): - return 0 - @classmethod def change_size(cls, rv, new_size, expand=False): dist_node = rv.tag.dist.owner @@ -124,6 +121,12 @@ def graph_rvs(cls, rv): return (rv.tag.dist,) +@_ndim_supp_dist.register(Clip) +def ndim_supp_censored(op, dist): + # We only support Censoring of univariate distributions + return 0 + + @_moment.register(Clip) def moment_censored(op, rv, dist, lower, upper): moment = at.switch( diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ac3541380b..a57adcedf5 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -44,6 +44,7 @@ convert_shape, convert_size, find_size, + ndim_supp_dist, resize_from_dims, resize_from_observed, ) @@ -399,16 +400,20 @@ def __new__( cls.rv_op Returns a TensorVariable that represents the symbolic distribution parametrized by a default set of parameters and a size and rngs arguments - cls.ndim_supp - Returns the support of the symbolic distribution, given the default - parameters. This may not always be constant, for instance if the symbolic - distribution can be defined based on an arbitrary base distribution. cls.change_size Returns an equivalent symbolic distribution with a different size. This is analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s. cls.graph_rvs Returns base RVs in a symbolic distribution. + Furthermore, Censored distributions must have a dispatch version of the following + functions for correct behavior in PyMC: + _ndim_supp_dist + Returns the support of the symbolic distribution. This may not always be + constant, for instance if the symbolic distribution can be defined based + on an arbitrary base distribution. 
This is called by + `pymc.distributions.shape_utils.ndim_supp_dist` + Parameters ---------- cls : type @@ -559,8 +564,11 @@ def dist( shape = convert_shape(shape) size = convert_size(size) + # Create a temporary dist to obtain the ndim_supp + ndim_supp = ndim_supp_dist(cls.rv_op(*dist_params, size=size)) + create_size, ndim_expected, ndim_batch, ndim_supp = find_size( - shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params) + shape=shape, size=size, ndim_supp=ndim_supp ) # Create the RV with a `size` right away. # This is not necessarily the final result. diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 909ac6a1c1..9625649708 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -31,7 +31,7 @@ from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicDistribution, _moment, moment from pymc.distributions.logprob import logcdf, logp -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist, to_tuple from pymc.distributions.transforms import _default_transform from pymc.util import check_dist_not_registered from pymc.vartypes import continuous_types, discrete_types @@ -188,7 +188,7 @@ def dist(cls, w, comp_dists, **kwargs): f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}" ) check_dist_not_registered(dist) - components_ndim_supp.add(dist.owner.op.ndim_supp) + components_ndim_supp.add(ndim_supp_dist(dist)) if len(components_ndim_supp) > 1: raise ValueError( @@ -209,7 +209,7 @@ def rv_op(cls, weights, *components, size=None, rngs=None): mix_indexes_rng = aesara.shared(np.random.default_rng()) single_component = len(components) == 1 - ndim_supp = components[0].owner.op.ndim_supp + ndim_supp = ndim_supp_dist(components[0]) if size is not None: components = cls._resize_components(size, *components) @@ -319,17 +319,12 @@ def _resize_components(cls, size, *components): if len(components) == 1: # If we have a single component, we need to keep the length of the mixture # axis intact, because that's what determines the number of mixture components - mix_axis = -components[0].owner.op.ndim_supp - 1 + mix_axis = -ndim_supp_dist(components[0]) - 1 mix_size = components[0].shape[mix_axis] size = tuple(size) + (mix_size,) return [change_rv_size(component, size) for component in components] - @classmethod - def ndim_supp(cls, weights, *components): - # We already checked that all components have the same support dimensionality - return components[0].owner.op.ndim_supp - @classmethod def change_size(cls, rv, new_size, expand=False): mix_indexes_rng, weights, *components = rv.owner.inputs @@ -338,7 +333,7 @@ def change_size(cls, rv, new_size, expand=False): if expand: component = components[0] # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0` - size_dims = component.ndim - component.owner.op.ndim_supp + size_dims = component.ndim - ndim_supp_dist(component) if len(components) == 1: # If we have a single component, new size should ignore the mixture axis # dimension, as that is not touched by `_resize_components` @@ -359,6 +354,13 @@ def graph_rvs(cls, rv): return (*rv.owner.inputs[2:], rv) +@_ndim_supp_dist.register(MarginalMixtureRV) +def ndim_supp_marginal_mixture(op, rv): + # We already checked that all components have the same support dimensionality + components = rv.owner.inputs[2:] + return ndim_supp_dist(components[0]) + + 
@_get_measurable_outputs.register(MarginalMixtureRV) def _get_measurable_outputs_MarginalMixtureRV(op, node): # This tells Aeppl that the second output is the measurable one @@ -372,7 +374,7 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs): # single component if len(components) == 1: # Need to broadcast value across mixture axis - mix_axis = -components[0].owner.op.ndim_supp - 1 + mix_axis = -ndim_supp_dist(components[0]) - 1 components_logp = logp(components[0], at.expand_dims(value, mix_axis)) else: components_logp = at.stack( @@ -405,7 +407,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs): # single component if len(components) == 1: # Need to broadcast value across mixture axis - mix_axis = -components[0].owner.op.ndim_supp - 1 + mix_axis = -ndim_supp_dist(components[0]) - 1 components_logcdf = logcdf(components[0], at.expand_dims(value, mix_axis)) else: components_logcdf = at.stack( @@ -434,7 +436,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs): @_moment.register(MarginalMixtureRV) def marginal_mixture_moment(op, rv, rng, weights, *components): - ndim_supp = components[0].owner.op.ndim_supp + ndim_supp = ndim_supp_dist(components[0]) weights = at.shape_padright(weights, ndim_supp) mix_axis = -ndim_supp - 1 diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 354061bab7..5943a2e71c 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -59,6 +59,7 @@ from pymc.distributions.distribution import Continuous, Discrete, moment from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, + ndim_supp_dist, rv_size_is_none, to_tuple, ) @@ -1187,7 +1188,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs): isinstance(sd_dist, Variable) and sd_dist.owner is not None and isinstance(sd_dist.owner.op, RandomVariable) - and sd_dist.owner.op.ndim_supp < 2 + and ndim_supp_dist(sd_dist) < 2 ): raise TypeError("sd_dist must be a scalar or vector distribution variable") @@ -1197,7 +1198,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs): # diagonal element. # Since `eta` and `n` are forced to be scalars we don't need to worry about # implied batched dimensions for the time being. - if sd_dist.owner.op.ndim_supp == 0: + if ndim_supp_dist(sd_dist) == 0: sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,)) else: # The support shape must be `n` but we have no way of controlling it diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 8110b3d280..248968f62e 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -17,12 +17,15 @@ A collection of common shape operations needed for broadcasting samples from probability distributions for stochastic nodes in PyMC. 
""" - +from functools import singledispatch from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast import numpy as np from aesara.graph.basic import Variable +from aesara.graph.op import Op +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias @@ -619,3 +622,23 @@ def find_size( def rv_size_is_none(size: Variable) -> bool: """Check wether an rv size is None (ie., at.Constant([]))""" return size.type.shape == (0,) # type: ignore [attr-defined] + + +@singledispatch +def _ndim_supp_dist(op: Op, dist: TensorVariable) -> int: + raise TypeError(f"ndim_supp not known for Op {op}") + + +def ndim_supp_dist(dist: TensorVariable) -> int: + return _ndim_supp_dist(dist.owner.op, dist) + + +@_ndim_supp_dist.register(RandomVariable) +def ndim_supp_rv(op: Op, rv: TensorVariable): + return op.ndim_supp + + +@_ndim_supp_dist.register(Elemwise) +def ndim_supp_elemwise(op: Op, *args, **kwargs): + """For Elemwise Ops, dispatch on respective scalar_op""" + return _ndim_supp_dist(op.scalar_op, *args, **kwargs) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index bf6addb900..b6516207db 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -24,7 +24,7 @@ from pymc.distributions import distribution, logprob, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.dist_math import check_parameters -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import ndim_supp_dist, to_tuple from pymc.util import check_dist_not_registered __all__ = [ @@ -175,7 +175,7 @@ def dist( isinstance(init, at.TensorVariable) and init.owner is not None and isinstance(init.owner.op, RandomVariable) - and init.owner.op.ndim_supp == 0 + and ndim_supp_dist(init) == 0 ): raise TypeError("init must be a univariate distribution variable") diff --git a/pymc/model.py b/pymc/model.py index 55303752aa..38be73fd7f 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -43,6 +43,7 @@ from aesara.compile.sharedvalue import SharedVariable from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph +from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.opt import local_subtensor_rv_lift from aesara.tensor.random.var import RandomStateSharedVariable from aesara.tensor.sharedvar import ScalarSharedVariable @@ -1330,6 +1331,12 @@ def make_obs_var( ) warnings.warn(impute_message, ImputationWarning) + # TODO: Add test for this + if not isinstance(rv_var.owner.op, RandomVariable): + raise NotImplementedError( + f"Automatic inputation is only supported for RandomVariables, but {rv_var} is of type {rv_var.owner.op}" + ) + if rv_var.owner.op.ndim_supp > 0: raise NotImplementedError( f"Automatic inputation is only supported for univariate RandomVariables, but {rv_var} is multivariate" From 4702bd413376e6bd01c34b10228619e906875a9d Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 8 Apr 2022 17:53:12 +0200 Subject: [PATCH 3/3] Create dispatched resize_dist --- pymc/aesaraf.py | 64 +--------------- pymc/distributions/censored.py | 28 ++++--- pymc/distributions/distribution.py | 12 +-- pymc/distributions/mixture.py | 52 +++++++------ pymc/distributions/multivariate.py | 7 +- pymc/distributions/shape_utils.py | 98 ++++++++++++++++++++++++- pymc/distributions/timeseries.py | 8 +- pymc/sampling.py | 5 +- 
pymc/tests/test_aesaraf.py | 73 +----------------- pymc/tests/test_distributions.py | 8 +- pymc/tests/test_distributions_random.py | 9 +-- pymc/tests/test_mixture.py | 10 +-- pymc/tests/test_shape_handling.py | 73 ++++++++++++++++++ 13 files changed, 246 insertions(+), 201 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index b263a20b9b..34795c5871 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -31,7 +31,7 @@ import scipy.sparse as sps from aeppl.logprob import CheckParameterValue -from aesara import config, scalar +from aesara import scalar from aesara.compile.mode import Mode, get_mode from aesara.gradient import grad from aesara.graph import local_optimizer @@ -45,17 +45,15 @@ walk, ) from aesara.graph.fg import FunctionGraph -from aesara.graph.op import Op, compute_test_value +from aesara.graph.op import Op from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream from aesara.scalar.basic import Cast from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable -from aesara.tensor.shape import SpecifyShape from aesara.tensor.sharedvar import SharedVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from aesara.tensor.var import TensorConstant, TensorVariable -from pymc.exceptions import ShapeError from pymc.vartypes import continuous_types, int_types, isgenerator, typefilter PotentialShapeType = Union[ @@ -142,64 +140,6 @@ def pandas_to_array(data): return floatX(ret) -def change_rv_size( - rv: TensorVariable, - new_size: PotentialShapeType, - expand: Optional[bool] = False, -) -> TensorVariable: - """Change or expand the size of a `RandomVariable`. - - Parameters - ========== - rv - The old `RandomVariable` output. - new_size - The new size. - expand: - Expand the existing size by `new_size`. - - """ - # Check the dimensionality of the `new_size` kwarg - new_size_ndim = np.ndim(new_size) - if new_size_ndim > 1: - raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim) - elif new_size_ndim == 0: - new_size = (new_size,) - - # Extract the RV node that is to be resized, together with its inputs, name and tag - if isinstance(rv.owner.op, SpecifyShape): - rv = rv.owner.inputs[0] - rv_node = rv.owner - rng, size, dtype, *dist_params = rv_node.inputs - name = rv.name - tag = rv.tag - - if expand: - shape = tuple(rv_node.op._infer_shape(size, dist_params)) - size = shape[: len(shape) - rv_node.op.ndim_supp] - new_size = tuple(new_size) + tuple(size) - - # Make sure the new size is a tensor. This dtype-aware conversion helps - # to not unnecessarily pick up a `Cast` in some cases (see #4652). 
- new_size = at.as_tensor(new_size, ndim=1, dtype="int64") - - new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) - new_rv = new_rv_node.outputs[-1] - new_rv.name = name - for k, v in tag.__dict__.items(): - new_rv.tag.__dict__.setdefault(k, v) - - # Update "traditional" rng default_update, if that was set for old RV - default_update = getattr(rng, "default_update", None) - if default_update is not None and default_update is rv_node.outputs[0]: - rng.default_update = new_rv_node.outputs[0] - - if config.compute_test_value != "off": - compute_test_value(new_rv_node) - - return new_rv - - def extract_rv_and_value_vars( var: TensorVariable, ) -> Tuple[TensorVariable, TensorVariable]: diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 084f9685c5..651d2223ec 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -19,7 +19,12 @@ from aesara.tensor.random.op import RandomVariable from pymc.distributions.distribution import SymbolicDistribution, _moment -from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist +from pymc.distributions.shape_utils import ( + _ndim_supp_dist, + _resize_dist, + ndim_supp_dist, + resize_dist, +) from pymc.util import check_dist_not_registered @@ -90,22 +95,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): rv_out.tag.upper = upper if size is not None: - rv_out = cls.change_size(rv_out, size) + rv_out = resize_dist(rv_out, size) if rngs is not None: rv_out = cls.change_rngs(rv_out, rngs) return rv_out - @classmethod - def change_size(cls, rv, new_size, expand=False): - dist_node = rv.tag.dist.owner - lower = rv.tag.lower - upper = rv.tag.upper - rng, old_size, dtype, *dist_params = dist_node.inputs - new_size = new_size if not expand else tuple(new_size) + tuple(old_size) - new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output() - return cls.rv_op(new_dist, lower, upper) - @classmethod def change_rngs(cls, rv, new_rngs): (new_rng,) = new_rngs @@ -127,6 +122,15 @@ def ndim_supp_censored(op, dist): return 0 +@_resize_dist.register(Clip) +def resize_censored(op, rv, new_size, expand=False): + dist = rv.tag.dist + lower = rv.tag.lower + upper = rv.tag.upper + new_dist = resize_dist(dist, new_size, expand=expand) + return Censored.rv_op(new_dist, lower, upper) + + @_moment.register(Clip) def moment_censored(op, rv, dist, lower, upper): moment = at.switch( diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index a57adcedf5..292228fecd 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -33,7 +33,6 @@ from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias -from pymc.aesaraf import change_rv_size from pymc.distributions.shape_utils import ( Dims, Shape, @@ -45,6 +44,7 @@ convert_size, find_size, ndim_supp_dist, + resize_dist, resize_from_dims, resize_from_observed, ) @@ -270,7 +270,7 @@ def __new__( if resize_shape: # A batch size was specified through `dims`, or implied by `observed`. 
- rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True) + rv_out = resize_dist(dist=rv_out, new_size=resize_shape, expand=True) rv_out = model.register_rv( rv_out, @@ -356,7 +356,7 @@ def dist( # Replicate dimensions may be prepended via a shape with Ellipsis as the last element: if shape is not None and Ellipsis in shape: replicate_shape = cast(StrongShape, shape[:-1]) - rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True) + rv_out = resize_dist(dist=rv_out, new_size=replicate_shape, expand=True) rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") @@ -400,9 +400,6 @@ def __new__( cls.rv_op Returns a TensorVariable that represents the symbolic distribution parametrized by a default set of parameters and a size and rngs arguments - cls.change_size - Returns an equivalent symbolic distribution with a different size. This is - analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s. cls.graph_rvs Returns base RVs in a symbolic distribution. @@ -413,6 +410,9 @@ def __new__( constant, for instance if the symbolic distribution can be defined based on an arbitrary base distribution. This is called by `pymc.distributions.shape_utils.ndim_supp_dist` + _resize_dist + Returns an equivalent symbolic distribution with a different size. This is + called by `pymc.distrributions.shape_utils.resize_dist`. Parameters ---------- diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 9625649708..b4e21d8865 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -25,13 +25,18 @@ from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable -from pymc.aesaraf import change_rv_size from pymc.distributions import transforms from pymc.distributions.continuous import Normal, get_tau_sigma from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicDistribution, _moment, moment from pymc.distributions.logprob import logcdf, logp -from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist, to_tuple +from pymc.distributions.shape_utils import ( + _ndim_supp_dist, + _resize_dist, + ndim_supp_dist, + resize_dist, + to_tuple, +) from pymc.distributions.transforms import _default_transform from pymc.util import check_dist_not_registered from pymc.vartypes import continuous_types, discrete_types @@ -323,27 +328,7 @@ def _resize_components(cls, size, *components): mix_size = components[0].shape[mix_axis] size = tuple(size) + (mix_size,) - return [change_rv_size(component, size) for component in components] - - @classmethod - def change_size(cls, rv, new_size, expand=False): - mix_indexes_rng, weights, *components = rv.owner.inputs - rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng] - - if expand: - component = components[0] - # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0` - size_dims = component.ndim - ndim_supp_dist(component) - if len(components) == 1: - # If we have a single component, new size should ignore the mixture axis - # dimension, as that is not touched by `_resize_components` - size_dims -= 1 - old_size = components[0].shape[:size_dims] - new_size = to_tuple(new_size) + tuple(old_size) - - components = cls._resize_components(new_size, *components) - - return cls.rv_op(weights, *components, rngs=rngs, size=None) + return [resize_dist(component, size) for component 
in components] @classmethod def graph_rvs(cls, rv): @@ -361,6 +346,27 @@ def ndim_supp_marginal_mixture(op, rv): return ndim_supp_dist(components[0]) +@_resize_dist.register(MarginalMixtureRV) +def resize_marginal_mixture(op, rv, new_size, expand=False): + mix_indexes_rng, weights, *components = rv.owner.inputs + rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng] + + if expand: + component = components[0] + # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0` + size_dims = component.ndim - ndim_supp_dist(component) + if len(components) == 1: + # If we have a single component, new size should ignore the mixture axis + # dimension, as that is not touched by `_resize_components` + size_dims -= 1 + old_size = components[0].shape[:size_dims] + new_size = to_tuple(new_size) + tuple(old_size) + + components = Mixture._resize_components(new_size, *components) + + return Mixture.rv_op(weights, *components, rngs=rngs, size=None) + + @_get_measurable_outputs.register(MarginalMixtureRV) def _get_measurable_outputs_MarginalMixtureRV(op, node): # This tells Aeppl that the second output is the measurable one diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 5943a2e71c..4fdf5a6fe4 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -41,7 +41,7 @@ import pymc as pm -from pymc.aesaraf import change_rv_size, floatX, intX +from pymc.aesaraf import floatX, intX from pymc.distributions import transforms from pymc.distributions.continuous import ( BoundedContinuous, @@ -60,6 +60,7 @@ from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, ndim_supp_dist, + resize_dist, rv_size_is_none, to_tuple, ) @@ -1199,10 +1200,10 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs): # Since `eta` and `n` are forced to be scalars we don't need to worry about # implied batched dimensions for the time being. if ndim_supp_dist(sd_dist) == 0: - sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,)) + sd_dist = resize_dist(sd_dist, to_tuple(size) + (n,)) else: # The support shape must be `n` but we have no way of controlling it - sd_dist = change_rv_size(sd_dist, to_tuple(size)) + sd_dist = resize_dist(sd_dist, to_tuple(size)) # sd_dist is part of the generative graph, but should be completely ignored # by the logp graph, since the LKJ logp explicitly includes these terms. 
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 248968f62e..8361ced1d1 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -22,14 +22,18 @@ import numpy as np +from aesara import config +from aesara import tensor as at from aesara.graph.basic import Variable -from aesara.graph.op import Op +from aesara.graph.op import Op, compute_test_value from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable +from aesara.tensor.shape import SpecifyShape from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias -from pymc.aesaraf import pandas_to_array +from pymc.exceptions import ShapeError +from pymc.aesaraf import PotentialShapeType, pandas_to_array __all__ = [ "to_tuple", @@ -642,3 +646,93 @@ def ndim_supp_rv(op: Op, rv: TensorVariable): def ndim_supp_elemwise(op: Op, *args, **kwargs): """For Elemwise Ops, dispatch on respective scalar_op""" return _ndim_supp_dist(op.scalar_op, *args, **kwargs) + + +@singledispatch +def _resize_dist( + op: Op, dist: TensorVariable, new_size: PotentialShapeType, expand: Optional[bool] = False +) -> TensorVariable: + raise NotImplementedError(f"resize not implemented for Op {op}") + + +def resize_dist( + dist: TensorVariable, new_size: PotentialShapeType, expand: Optional[bool] = False +) -> TensorVariable: + """Change or expand the size of a Distribution variable. + + Parameters + ========== + dist + The old distibution output. + new_size + The new size. + expand + Expand the existing size by `new_size`. + """ + new_size_ndim = np.ndim(new_size) + if new_size_ndim > 1: + raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim) + elif new_size_ndim == 0: + new_size = (new_size,) + + new_dist = _resize_dist(dist.owner.op, dist, new_size, expand) + + new_dist.name = dist.name + for k, v in dist.tag.__dict__.items(): + new_dist.tag.__dict__.setdefault(k, v) + + if config.compute_test_value != "off": + compute_test_value(new_dist.owner) + + return new_dist + + +@_resize_dist.register(RandomVariable) +def resize_rv( + op: Op, + rv: TensorVariable, + new_size: PotentialShapeType, + expand: Optional[bool] = False, +) -> TensorVariable: + """Change or expand the size of a `RandomVariable`. + + Parameters + ========== + rv + The old `RandomVariable` output. + new_size + The new size. + expand: + Expand the existing size by `new_size`. + + """ + # Extract the RV node that is to be resized, together with its inputs, name and tag + if isinstance(rv.owner.op, SpecifyShape): + rv = rv.owner.inputs[0] + rv_node = rv.owner + rng, size, dtype, *dist_params = rv_node.inputs + + if expand: + shape = tuple(rv_node.op._infer_shape(size, dist_params)) + size = shape[: len(shape) - rv_node.op.ndim_supp] + new_size = tuple(new_size) + tuple(size) + + # Make sure the new size is a tensor. This dtype-aware conversion helps + # to not unnecessarily pick up a `Cast` in some cases (see #4652). 
+ new_size = at.as_tensor(new_size, ndim=1, dtype="int64") + + new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) + new_rv = new_rv_node.outputs[-1] + + # Update "traditional" rng default_update, if that was set for old RV + default_update = getattr(rng, "default_update", None) + if default_update is not None and default_update is rv_node.outputs[0]: + rng.default_update = new_rv_node.outputs[0] + + return new_rv + + +@_resize_dist.register(Elemwise) +def resize_elemwise(op: Op, *args, **kwargs): + """For Elemwise Ops, dispatch on respective scalar_op""" + return _resize_dist(op.scalar_op, *args, **kwargs) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index b6516207db..c8cade6c1c 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -20,11 +20,11 @@ from aesara import scan from aesara.tensor.random.op import RandomVariable -from pymc.aesaraf import change_rv_size, floatX, intX +from pymc.aesaraf import floatX, intX from pymc.distributions import distribution, logprob, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.dist_math import check_parameters -from pymc.distributions.shape_utils import ndim_supp_dist, to_tuple +from pymc.distributions.shape_utils import ndim_supp_dist, resize_dist, to_tuple from pymc.util import check_dist_not_registered __all__ = [ @@ -180,11 +180,11 @@ def dist( raise TypeError("init must be a univariate distribution variable") if init_size is not None: - init = change_rv_size(init, init_size) + init = resize_dist(init, init_size) else: # If not explicit, size is determined by the shapes of mu, sigma, and init bcast_shape = at.broadcast_arrays(mu, sigma, init)[0].shape - init = change_rv_size(init, bcast_shape) + init = resize_dist(init, bcast_shape) # Ignores logprob of init var because that's accounted for in the logp method init.tag.ignore_logprob = True diff --git a/pymc/sampling.py b/pymc/sampling.py index b6e02dc885..e3336f1f8a 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -52,11 +52,12 @@ import pymc as pm -from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model +from pymc.aesaraf import compile_pymc, inputvars, walk_model from pymc.backends.arviz import _DefaultTrace from pymc.backends.base import BaseTrace, MultiTrace from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection +from pymc.distributions.shape_utils import resize_dist from pymc.exceptions import IncorrectArgumentsError, SamplingError from pymc.initial_point import ( PointType, @@ -1743,7 +1744,7 @@ def sample_posterior_predictive( inputs = [model[n] for n in input_names] if size is not None: - vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample] + vars_to_sample = [resize_dist(v, size, expand=True) for v in vars_to_sample] if compile_kwargs is None: compile_kwargs = {} diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 363437490e..db70c05321 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -23,7 +23,7 @@ import scipy.sparse as sps from aeppl.logprob import ParameterValueError -from aesara.graph.basic import Constant, Variable, ancestors, equal_computations +from aesara.graph.basic import Variable, equal_computations from aesara.tensor.random.basic import normal, uniform from aesara.tensor.random.op import RandomVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -34,7 
+34,6 @@ from pymc.aesaraf import ( _conversion_map, - change_rv_size, compile_pymc, extract_obs_data, pandas_to_array, @@ -43,79 +42,9 @@ walk_model, ) from pymc.distributions.dist_math import check_parameters -from pymc.exceptions import ShapeError from pymc.vartypes import int_types -def test_change_rv_size(): - loc = at.as_tensor_variable([1, 2]) - rv = normal(loc=loc) - assert rv.ndim == 1 - assert tuple(rv.shape.eval()) == (2,) - - with pytest.raises(ShapeError, match="must be ≤1-dimensional"): - change_rv_size(rv, new_size=[[2, 3]]) - with pytest.raises(ShapeError, match="must be ≤1-dimensional"): - change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]])) - - rv_new = change_rv_size(rv, new_size=(3,), expand=True) - assert rv_new.ndim == 2 - assert tuple(rv_new.shape.eval()) == (3, 2) - - # Make sure that the shape used to determine the expanded size doesn't - # depend on the old `RandomVariable`. - rv_new_ancestors = set(ancestors((rv_new,))) - assert loc in rv_new_ancestors - assert rv not in rv_new_ancestors - - rv_newer = change_rv_size(rv_new, new_size=(4,), expand=True) - assert rv_newer.ndim == 3 - assert tuple(rv_newer.shape.eval()) == (4, 3, 2) - - # Make sure we avoid introducing a `Cast` by converting the new size before - # constructing the new `RandomVariable` - rv = normal(0, 1) - new_size = np.array([4, 3], dtype="int32") - rv_newer = change_rv_size(rv, new_size=new_size, expand=False) - assert rv_newer.ndim == 2 - assert isinstance(rv_newer.owner.inputs[1], Constant) - assert tuple(rv_newer.shape.eval()) == (4, 3) - - rv = normal(0, 1) - new_size = at.as_tensor(np.array([4, 3], dtype="int32")) - rv_newer = change_rv_size(rv, new_size=new_size, expand=True) - assert rv_newer.ndim == 2 - assert tuple(rv_newer.shape.eval()) == (4, 3) - - rv = normal(0, 1) - new_size = at.as_tensor(2, dtype="int32") - rv_newer = change_rv_size(rv, new_size=new_size, expand=True) - assert rv_newer.ndim == 1 - assert tuple(rv_newer.shape.eval()) == (2,) - - -def test_change_rv_size_default_update(): - rng = aesara.shared(np.random.default_rng(0)) - x = normal(rng=rng) - - # Test that "traditional" default_update is updated - rng.default_update = x.owner.outputs[0] - new_x = change_rv_size(x, new_size=(2,)) - assert rng.default_update is not x.owner.outputs[0] - assert rng.default_update is new_x.owner.outputs[0] - - # Test that "non-traditional" default_update is left unchanged - next_rng = aesara.shared(np.random.default_rng(1)) - rng.default_update = next_rng - new_x = change_rv_size(x, new_size=(2,)) - assert rng.default_update is next_rng - - # Test that default_update is not set if there was none before - del rng.default_update - new_x = change_rv_size(x, new_size=(2,)) - assert not hasattr(rng, "default_update") - - class TestBroadcasting: def test_make_shared_replacements(self): """Check if pm.make_shared_replacements preserves broadcasting.""" diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 4f2c86e457..a3e5c44f3a 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -126,7 +126,7 @@ def polyagamma_cdf(*args, **kwargs): logcdf, logp, ) -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import resize_dist, to_tuple from pymc.math import kronecker from pymc.model import Deterministic, Model, Point, Potential from pymc.tests.helpers import select_by_precision @@ -3371,13 +3371,13 @@ def test_censored_invalid_dist(self): ): x = pm.Censored("x", registered_dist, 
lower=None, upper=None) - def test_change_size(self): + def test_resize_dist(self): base_dist = pm.Censored.dist(pm.Normal.dist(), -1, 1, size=(3, 2)) - new_dist = pm.Censored.change_size(base_dist, (4,)) + new_dist = resize_dist(base_dist, (4,)) assert new_dist.eval().shape == (4,) - new_dist = pm.Censored.change_size(base_dist, (4,), expand=True) + new_dist = resize_dist(base_dist, (4,), expand=True) assert new_dist.eval().shape == (4, 3, 2) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 9fd7c2ce81..e11fe3923d 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -42,7 +42,7 @@ def random_polyagamma(*args, **kwargs): import pymc as pm -from pymc.aesaraf import change_rv_size, compile_pymc, floatX, intX +from pymc.aesaraf import compile_pymc, floatX, intX from pymc.distributions.continuous import get_tau_sigma, interpolated from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit from pymc.distributions.dist_math import clipped_beta_rvs @@ -52,7 +52,7 @@ def random_polyagamma(*args, **kwargs): _OrderedMultinomial, quaddist_matrix, ) -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import resize_dist, to_tuple from pymc.tests.helpers import SeededTest, select_by_precision from pymc.tests.test_distributions import ( Domain, @@ -74,7 +74,6 @@ def pymc_random( fails=10, extra_args=None, model_args=None, - change_rv_size_fn=change_rv_size, ): if valuedomain is None: valuedomain = Domain([0], edges=(None, None)) @@ -83,7 +82,7 @@ def pymc_random( model_args = {} model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args) - model_dist = change_rv_size_fn(model.named_vars["value"], size, expand=True) + model_dist = resize_dist(model.named_vars["value"], size, expand=True) pymc_rand = compile_pymc([], model_dist) domains = paramdomains.copy() @@ -122,7 +121,7 @@ def pymc_random_discrete( valuedomain = Domain([0], edges=(None, None)) model, param_vars = build_model(dist, valuedomain, paramdomains) - model_dist = change_rv_size(model.named_vars["value"], size, expand=True) + model_dist = resize_dist(model.named_vars["value"], size, expand=True) pymc_rand = compile_pymc([], model_dist) domains = paramdomains.copy() diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index 2cab2151ae..d1babe3a6e 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -50,7 +50,7 @@ ) from pymc.distributions.logprob import logp from pymc.distributions.mixture import MixtureTransformWarning -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.shape_utils import resize_dist, to_tuple from pymc.distributions.transforms import _default_transform from pymc.math import expand_packed_triangular from pymc.model import Model @@ -384,18 +384,18 @@ def test_components_expanded_by_weights(self, comp_dists): ), ) @pytest.mark.parametrize("expand", (False, True)) - def test_change_size(self, comp_dists, expand): + def test_resize_dist(self, comp_dists, expand): univariate = comp_dists[0].owner.op.ndim_supp == 0 mix = Mixture.dist(w=Dirichlet.dist([1, 1]), comp_dists=comp_dists) - mix = Mixture.change_size(mix, new_size=(4,), expand=expand) + mix = resize_dist(mix, new_size=(4,), expand=expand) draws = mix.eval() expected_shape = (4,) if univariate else (4, 3) assert draws.shape == expected_shape assert np.unique(draws).size == draws.size mix = Mixture.dist(w=Dirichlet.dist([1, 1]), 
comp_dists=comp_dists, size=(3,)) - mix = Mixture.change_size(mix, new_size=(5, 4), expand=expand) + mix = resize_dist(mix, new_size=(5, 4), expand=expand) draws = mix.eval() expected_shape = (5, 4) if univariate else (5, 4, 3) if expand: @@ -829,7 +829,6 @@ def ref_rand(size, w, mu, sigma): extra_args={"comp_shape": 2}, size=1000, ref_rand=ref_rand, - change_rv_size_fn=Mixture.change_size, ) pymc_random( NormalMixture, @@ -841,7 +840,6 @@ def ref_rand(size, w, mu, sigma): extra_args={"comp_shape": 3}, size=1000, ref_rand=ref_rand, - change_rv_size_fn=Mixture.change_size, ) diff --git a/pymc/tests/test_shape_handling.py b/pymc/tests/test_shape_handling.py index 77fd878ca3..e1172454bb 100644 --- a/pymc/tests/test_shape_handling.py +++ b/pymc/tests/test_shape_handling.py @@ -17,9 +17,12 @@ import pytest from aesara import tensor as at +from aesara.graph import Constant, ancestors +from aesara.tensor.random import normal import pymc as pm +from pymc import ShapeError from pymc.distributions.shape_utils import ( broadcast_dist_samples_shape, broadcast_dist_samples_to, @@ -28,6 +31,7 @@ convert_shape, convert_size, get_broadcastable_dist_samples, + resize_rv, shapes_broadcasting, to_tuple, ) @@ -476,3 +480,72 @@ def test_size_from_observed_rng_update(self): # Confirm that the rng is properly offset, otherwise the second value of the first # draw, would match the first value of the second draw assert fn()[1] != fn()[0] + + +def test_resize_dist_rv(): + loc = at.as_tensor_variable([1, 2]) + rv = normal(loc=loc) + assert rv.ndim == 1 + assert tuple(rv.shape.eval()) == (2,) + + with pytest.raises(ShapeError, match="must be ≤1-dimensional"): + resize_rv(rv, new_size=[[2, 3]]) + with pytest.raises(ShapeError, match="must be ≤1-dimensional"): + resize_rv(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]])) + + rv_new = resize_rv(rv, new_size=(3,), expand=True) + assert rv_new.ndim == 2 + assert tuple(rv_new.shape.eval()) == (3, 2) + + # Make sure that the shape used to determine the expanded size doesn't + # depend on the old `RandomVariable`. 
+ rv_new_ancestors = set(ancestors((rv_new,))) + assert loc in rv_new_ancestors + assert rv not in rv_new_ancestors + + rv_newer = resize_rv(rv_new, new_size=(4,), expand=True) + assert rv_newer.ndim == 3 + assert tuple(rv_newer.shape.eval()) == (4, 3, 2) + + # Make sure we avoid introducing a `Cast` by converting the new size before + # constructing the new `RandomVariable` + rv = normal(0, 1) + new_size = np.array([4, 3], dtype="int32") + rv_newer = resize_rv(rv, new_size=new_size, expand=False) + assert rv_newer.ndim == 2 + assert isinstance(rv_newer.owner.inputs[1], Constant) + assert tuple(rv_newer.shape.eval()) == (4, 3) + + rv = normal(0, 1) + new_size = at.as_tensor(np.array([4, 3], dtype="int32")) + rv_newer = resize_rv(rv, new_size=new_size, expand=True) + assert rv_newer.ndim == 2 + assert tuple(rv_newer.shape.eval()) == (4, 3) + + rv = normal(0, 1) + new_size = at.as_tensor(2, dtype="int32") + rv_newer = resize_rv(rv, new_size=new_size, expand=True) + assert rv_newer.ndim == 1 + assert tuple(rv_newer.shape.eval()) == (2,) + + +def test_resize_dist_rv_default_update(): + rng = aesara.shared(np.random.default_rng(0)) + x = normal(rng=rng) + + # Test that "traditional" default_update is updated + rng.default_update = x.owner.outputs[0] + new_x = resize_rv(x, new_size=(2,)) + assert rng.default_update is not x.owner.outputs[0] + assert rng.default_update is new_x.owner.outputs[0] + + # Test that "non-traditional" default_update is left unchanged + next_rng = aesara.shared(np.random.default_rng(1)) + rng.default_update = next_rng + new_x = resize_rv(x, new_size=(2,)) + assert rng.default_update is next_rng + + # Test that default_update is not set if there was none before + del rng.default_update + new_x = resize_rv(x, new_size=(2,)) + assert not hasattr(rng, "default_update")
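
The three patches above replace the ad-hoc `rv.tag` references and the per-class
`ndim_supp`/`change_size` classmethods with the singledispatch helpers
`ndim_supp_dist` and `resize_dist` in `pymc.distributions.shape_utils`. Below is a
minimal usage sketch, assuming a PyMC build with these patches applied; the import
paths and expected shapes follow the diffs and the tests added in PATCH 3/3, not a
released PyMC version.

    import pymc as pm

    from pymc.distributions.shape_utils import ndim_supp_dist, resize_dist

    # Plain RandomVariable: dispatch falls back to the handler registered for
    # RandomVariable, which simply returns `op.ndim_supp`.
    x = pm.Normal.dist(mu=0.0, sigma=1.0, size=(3,))
    assert ndim_supp_dist(x) == 0

    # Symbolic (OpFromGraph-based) distribution: dispatch goes through the
    # `_ndim_supp_dist` implementation registered for MarginalMixtureRV,
    # instead of reading `rv.tag` attributes or calling a classmethod.
    mix = pm.Mixture.dist(
        w=[0.5, 0.5],
        comp_dists=[pm.Normal.dist(mu=-5.0), pm.Normal.dist(mu=5.0)],
    )
    assert ndim_supp_dist(mix) == 0

    # `resize_dist` takes over from `pymc.aesaraf.change_rv_size` and the
    # per-class `change_size` classmethods; with `expand=True` the new size
    # is prepended to the existing one.
    bigger = resize_dist(mix, new_size=(4,), expand=True)
    assert bigger.eval().shape == (4,)  # matches the univariate case in test_mixture.py

Because both helpers dispatch on `dist.owner.op`, a new symbolic distribution only
needs to register `_ndim_supp_dist` and `_resize_dist` implementations for its Op to
work with `.dist(shape=...)`, `dims`/`observed` resizing, and
`sample_posterior_predictive(..., size=...)`.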