Create dispatch functions for resize_dist and ndim_supp_dist #5702

Closed
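The helpers this PR dispatches on live in `pymc/distributions/shape_utils.py`, whose changes are not shown in this view. A minimal sketch of what such singledispatch-based helpers could look like, inferred from the imports and registrations in the diffs below (the exact implementation in the PR may differ):

```python
from functools import singledispatch

from aesara.graph.op import Op
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable


@singledispatch
def _ndim_supp_dist(op: Op, dist: TensorVariable) -> int:
    raise NotImplementedError(f"ndim_supp of {op} not implemented")


@_ndim_supp_dist.register(RandomVariable)
def _ndim_supp_random_variable(op, dist):
    # Plain RandomVariables carry their support dimensionality on the Op itself
    return op.ndim_supp


def ndim_supp_dist(dist: TensorVariable) -> int:
    # Dispatch on the Op that created `dist`, so symbolic distributions such as
    # Censored (Clip) and Mixture (MarginalMixtureRV) can register their own behavior
    return _ndim_supp_dist(dist.owner.op, dist)


@singledispatch
def _resize_dist(op: Op, dist: TensorVariable, new_size, expand: bool = False) -> TensorVariable:
    raise NotImplementedError(f"resize of {op} not implemented")


def resize_dist(dist: TensorVariable, new_size, expand: bool = False) -> TensorVariable:
    return _resize_dist(dist.owner.op, dist, new_size, expand=expand)
```

The point of the dispatch design is that `ndim_supp_dist` and `resize_dist` work on any distribution variable, looking up behavior by the `Op` that created it, instead of requiring special `ndim_supp`/`change_size` classmethods on each symbolic distribution class.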
64 changes: 2 additions & 62 deletions pymc/aesaraf.py
@@ -31,7 +31,7 @@
 import scipy.sparse as sps
 
 from aeppl.logprob import CheckParameterValue
-from aesara import config, scalar
+from aesara import scalar
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import local_optimizer
@@ -45,17 +45,15 @@
     walk,
 )
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op, compute_test_value
+from aesara.graph.op import Op
 from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
 from aesara.scalar.basic import Cast
 from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.shape import SpecifyShape
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
-from pymc.exceptions import ShapeError
 from pymc.vartypes import continuous_types, int_types, isgenerator, typefilter
 
 PotentialShapeType = Union[
@@ -142,64 +140,6 @@ def pandas_to_array(data):
     return floatX(ret)
 
 
-def change_rv_size(
-    rv: TensorVariable,
-    new_size: PotentialShapeType,
-    expand: Optional[bool] = False,
-) -> TensorVariable:
-    """Change or expand the size of a `RandomVariable`.
-
-    Parameters
-    ==========
-    rv
-        The old `RandomVariable` output.
-    new_size
-        The new size.
-    expand:
-        Expand the existing size by `new_size`.
-
-    """
-    # Check the dimensionality of the `new_size` kwarg
-    new_size_ndim = np.ndim(new_size)
-    if new_size_ndim > 1:
-        raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
-    elif new_size_ndim == 0:
-        new_size = (new_size,)
-
-    # Extract the RV node that is to be resized, together with its inputs, name and tag
-    if isinstance(rv.owner.op, SpecifyShape):
-        rv = rv.owner.inputs[0]
-    rv_node = rv.owner
-    rng, size, dtype, *dist_params = rv_node.inputs
-    name = rv.name
-    tag = rv.tag
-
-    if expand:
-        shape = tuple(rv_node.op._infer_shape(size, dist_params))
-        size = shape[: len(shape) - rv_node.op.ndim_supp]
-        new_size = tuple(new_size) + tuple(size)
-
-    # Make sure the new size is a tensor. This dtype-aware conversion helps
-    # to not unnecessarily pick up a `Cast` in some cases (see #4652).
-    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
-
-    new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
-    new_rv = new_rv_node.outputs[-1]
-    new_rv.name = name
-    for k, v in tag.__dict__.items():
-        new_rv.tag.__dict__.setdefault(k, v)
-
-    # Update "traditional" rng default_update, if that was set for old RV
-    default_update = getattr(rng, "default_update", None)
-    if default_update is not None and default_update is rv_node.outputs[0]:
-        rng.default_update = new_rv_node.outputs[0]
-
-    if config.compute_test_value != "off":
-        compute_test_value(new_rv_node)
-
-    return new_rv
-
-
 def extract_rv_and_value_vars(
     var: TensorVariable,
 ) -> Tuple[TensorVariable, TensorVariable]:
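The removed `change_rv_size` is not simply dropped: callers below switch to the generic `resize_dist`, so the resizing logic presumably moves into a `RandomVariable` registration of `_resize_dist` in `shape_utils.py` (not shown in this view). A condensed sketch of what that registration could look like, adapted from the code deleted above (the function name and location are assumptions; the tag-copying and rng `default_update` handling are omitted for brevity):

```python
import aesara.tensor as at
import numpy as np

from aesara.tensor.random.op import RandomVariable

from pymc.distributions.shape_utils import _resize_dist


@_resize_dist.register(RandomVariable)
def resize_random_variable(op, rv, new_size, expand=False):
    # Same core steps as the removed `change_rv_size`: normalize `new_size`,
    # optionally prepend it to the old batch shape, then rebuild the node
    if np.ndim(new_size) == 0:
        new_size = (new_size,)

    rng, size, dtype, *dist_params = rv.owner.inputs
    if expand:
        shape = tuple(op._infer_shape(size, dist_params))
        new_size = tuple(new_size) + tuple(shape[: len(shape) - op.ndim_supp])

    # dtype-aware conversion avoids picking up a spurious `Cast` (see #4652)
    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
    return op.make_node(rng, new_size, dtype, *dist_params).default_output()
```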
39 changes: 23 additions & 16 deletions pymc/distributions/censored.py
@@ -19,6 +19,12 @@
 from aesara.tensor.random.op import RandomVariable
 
 from pymc.distributions.distribution import SymbolicDistribution, _moment
+from pymc.distributions.shape_utils import (
+    _ndim_supp_dist,
+    _resize_dist,
+    ndim_supp_dist,
+    resize_dist,
+)
 from pymc.util import check_dist_not_registered
 
 
@@ -65,7 +71,7 @@ def dist(cls, dist, lower, upper, **kwargs):
             raise ValueError(
                 f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
             )
-        if dist.owner.op.ndim_supp > 0:
+        if ndim_supp_dist(dist) > 0:
             raise NotImplementedError(
                 "Censoring of multivariate distributions has not been implemented yet"
             )
@@ -89,26 +95,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
         rv_out.tag.upper = upper
 
         if size is not None:
-            rv_out = cls.change_size(rv_out, size)
+            rv_out = resize_dist(rv_out, size)
         if rngs is not None:
             rv_out = cls.change_rngs(rv_out, rngs)
 
         return rv_out
 
-    @classmethod
-    def ndim_supp(cls, *dist_params):
-        return 0
-
-    @classmethod
-    def change_size(cls, rv, new_size, expand=False):
-        dist_node = rv.tag.dist.owner
-        lower = rv.tag.lower
-        upper = rv.tag.upper
-        rng, old_size, dtype, *dist_params = dist_node.inputs
-        new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
-        new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
-        return cls.rv_op(new_dist, lower, upper)
-
     @classmethod
     def change_rngs(cls, rv, new_rngs):
         (new_rng,) = new_rngs
@@ -124,6 +116,21 @@ def graph_rvs(cls, rv):
         return (rv.tag.dist,)
 
 
+@_ndim_supp_dist.register(Clip)
+def ndim_supp_censored(op, dist):
+    # We only support Censoring of univariate distributions
+    return 0
+
+
+@_resize_dist.register(Clip)
+def resize_censored(op, rv, new_size, expand=False):
+    dist = rv.tag.dist
+    lower = rv.tag.lower
+    upper = rv.tag.upper
+    new_dist = resize_dist(dist, new_size, expand=expand)
+    return Censored.rv_op(new_dist, lower, upper)
+
+
 @_moment.register(Clip)
 def moment_censored(op, rv, dist, lower, upper):
     moment = at.switch(
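With the `Clip` registrations above, the generic helpers now handle censored variables without reaching into the `Censored` class. A hypothetical usage sketch, assuming this PR's branch:

```python
import pymc as pm

from pymc.distributions.shape_utils import ndim_supp_dist, resize_dist

# Censored.dist wraps the base distribution in a `Clip` Op, so the
# dispatch implementations registered above apply
censored = pm.Censored.dist(pm.Normal.dist(), lower=-1.0, upper=1.0)

assert ndim_supp_dist(censored) == 0  # only univariate censoring is supported

# Builds a new Censored variable around a resized base distribution
resized = resize_dist(censored, (10,))
```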
30 changes: 19 additions & 11 deletions pymc/distributions/distribution.py
@@ -33,7 +33,6 @@
 from aesara.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
 
-from pymc.aesaraf import change_rv_size
 from pymc.distributions.shape_utils import (
     Dims,
     Shape,
@@ -44,6 +43,8 @@
     convert_shape,
     convert_size,
     find_size,
+    ndim_supp_dist,
+    resize_dist,
     resize_from_dims,
     resize_from_observed,
 )
@@ -269,7 +270,7 @@ def __new__(
 
         if resize_shape:
             # A batch size was specified through `dims`, or implied by `observed`.
-            rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
+            rv_out = resize_dist(dist=rv_out, new_size=resize_shape, expand=True)
 
         rv_out = model.register_rv(
             rv_out,
@@ -355,7 +356,7 @@ def dist(
         # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
         if shape is not None and Ellipsis in shape:
             replicate_shape = cast(StrongShape, shape[:-1])
-            rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True)
+            rv_out = resize_dist(dist=rv_out, new_size=replicate_shape, expand=True)
 
         rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
         rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
@@ -399,16 +400,20 @@ def __new__(
     cls.rv_op
         Returns a TensorVariable that represents the symbolic distribution
         parametrized by a default set of parameters and a size and rngs arguments
-    cls.ndim_supp
-        Returns the support of the symbolic distribution, given the default
-        parameters. This may not always be constant, for instance if the symbolic
-        distribution can be defined based on an arbitrary base distribution.
-    cls.change_size
-        Returns an equivalent symbolic distribution with a different size. This is
-        analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
     cls.graph_rvs
         Returns base RVs in a symbolic distribution.
 
+    Furthermore, symbolic distributions must have a dispatch version of the following
+    functions for correct behavior in PyMC:
+    _ndim_supp_dist
+        Returns the support of the symbolic distribution. This may not always be
+        constant, for instance if the symbolic distribution can be defined based
+        on an arbitrary base distribution. This is called by
+        `pymc.distributions.shape_utils.ndim_supp_dist`
+    _resize_dist
+        Returns an equivalent symbolic distribution with a different size. This is
+        called by `pymc.distributions.shape_utils.resize_dist`.
+
     Parameters
     ----------
     cls : type
@@ -559,8 +564,11 @@ def dist(
         shape = convert_shape(shape)
         size = convert_size(size)
 
+        # Create a temporary dist to obtain the ndim_supp
+        ndim_supp = ndim_supp_dist(cls.rv_op(*dist_params, size=size))
+
         create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
-            shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
+            shape=shape, size=size, ndim_supp=ndim_supp
         )
         # Create the RV with a `size` right away.
         # This is not necessarily the final result.

Review comment on lines +567 to +568 (Member Author):
> This is not great, but can't think of a less wasteful way
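Per the updated docstring above, a downstream symbolic distribution now registers the two dispatch functions instead of defining `ndim_supp`/`change_size` classmethods. A sketch of the pattern for an illustrative `MySymbolicRV` Op (all names here are placeholders, not part of the PR):

```python
from aesara.compile.builders import OpFromGraph

from pymc.distributions.shape_utils import (
    _ndim_supp_dist,
    _resize_dist,
    ndim_supp_dist,
    resize_dist,
)


class MySymbolicRV(OpFromGraph):
    """Illustrative placeholder for a symbolic distribution's Op."""


@_ndim_supp_dist.register(MySymbolicRV)
def ndim_supp_my_symbolic(op, dist):
    # Defer to the base distribution stored among the node inputs;
    # here assumed to be the last input, as in MarginalMixtureRV
    return ndim_supp_dist(dist.owner.inputs[-1])


@_resize_dist.register(MySymbolicRV)
def resize_my_symbolic(op, dist, new_size, expand=False):
    base_dist = dist.owner.inputs[-1]
    new_base = resize_dist(base_dist, new_size, expand=expand)
    # Rebuild the symbolic distribution around the resized base;
    # `MySymbolicDist.rv_op` is a hypothetical wrapper classmethod
    return MySymbolicDist.rv_op(new_base)
```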
84 changes: 43 additions & 41 deletions pymc/distributions/mixture.py
@@ -25,13 +25,18 @@
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
-from pymc.aesaraf import change_rv_size
 from pymc.distributions import transforms
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
 from pymc.distributions.logprob import logcdf, logp
-from pymc.distributions.shape_utils import to_tuple
+from pymc.distributions.shape_utils import (
+    _ndim_supp_dist,
+    _resize_dist,
+    ndim_supp_dist,
+    resize_dist,
+    to_tuple,
+)
 from pymc.distributions.transforms import _default_transform
 from pymc.util import check_dist_not_registered
 from pymc.vartypes import continuous_types, discrete_types
@@ -188,7 +193,7 @@ def dist(cls, w, comp_dists, **kwargs):
                     f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
                 )
             check_dist_not_registered(dist)
-            components_ndim_supp.add(dist.owner.op.ndim_supp)
+            components_ndim_supp.add(ndim_supp_dist(dist))
 
         if len(components_ndim_supp) > 1:
             raise ValueError(
@@ -209,7 +214,7 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
         mix_indexes_rng = aesara.shared(np.random.default_rng())
 
         single_component = len(components) == 1
-        ndim_supp = components[0].owner.op.ndim_supp
+        ndim_supp = ndim_supp_dist(components[0])
 
         if size is not None:
             components = cls._resize_components(size, *components)
@@ -294,11 +299,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
         # inside OpFromGraph and PyMC will never find it otherwise
         mix_indexes_rng.default_update = mix_out.owner.outputs[0]
 
-        # Reference nodes to facilitate identification in other classmethods
-        mix_out.tag.weights = weights
-        mix_out.tag.components = components
-        mix_out.tag.choices_rng = mix_indexes_rng
-
         # Component RVs terms are accounted by the Mixture logprob, so they can be
         # safely ignore by Aeppl (this tag prevents UserWarning)
         for component in components:
@@ -324,45 +324,47 @@ def _resize_components(cls, size, *components):
         if len(components) == 1:
             # If we have a single component, we need to keep the length of the mixture
             # axis intact, because that's what determines the number of mixture components
-            mix_axis = -components[0].owner.op.ndim_supp - 1
+            mix_axis = -ndim_supp_dist(components[0]) - 1
             mix_size = components[0].shape[mix_axis]
             size = tuple(size) + (mix_size,)
 
-        return [change_rv_size(component, size) for component in components]
-
-    @classmethod
-    def ndim_supp(cls, weights, *components):
-        # We already checked that all components have the same support dimensionality
-        return components[0].owner.op.ndim_supp
-
-    @classmethod
-    def change_size(cls, rv, new_size, expand=False):
-        weights = rv.tag.weights
-        components = rv.tag.components
-        rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng]
-
-        if expand:
-            component = rv.tag.components[0]
-            # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
-            size_dims = component.ndim - component.owner.op.ndim_supp
-            if len(rv.tag.components) == 1:
-                # If we have a single component, new size should ignore the mixture axis
-                # dimension, as that is not touched by `_resize_components`
-                size_dims -= 1
-            old_size = components[0].shape[:size_dims]
-            new_size = to_tuple(new_size) + tuple(old_size)
-
-        components = cls._resize_components(new_size, *components)
-
-        return cls.rv_op(weights, *components, rngs=rngs, size=None)
+        return [resize_dist(component, size) for component in components]
 
     @classmethod
     def graph_rvs(cls, rv):
         # We return rv, which is itself a pseudo RandomVariable, that contains a
         # mix_indexes_ RV in its inner graph. We want super().dist() to generate
         # (components + 1) rngs for us, and it will do so based on how many elements
         # we return here
-        return (*rv.tag.components, rv)
+        return (*rv.owner.inputs[2:], rv)
 
 
+@_ndim_supp_dist.register(MarginalMixtureRV)
+def ndim_supp_marginal_mixture(op, rv):
+    # We already checked that all components have the same support dimensionality
+    components = rv.owner.inputs[2:]
+    return ndim_supp_dist(components[0])
+
+
+@_resize_dist.register(MarginalMixtureRV)
+def resize_marginal_mixture(op, rv, new_size, expand=False):
+    mix_indexes_rng, weights, *components = rv.owner.inputs
+    rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng]
+
+    if expand:
+        component = components[0]
+        # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
+        size_dims = component.ndim - ndim_supp_dist(component)
+        if len(components) == 1:
+            # If we have a single component, new size should ignore the mixture axis
+            # dimension, as that is not touched by `_resize_components`
+            size_dims -= 1
+        old_size = components[0].shape[:size_dims]
+        new_size = to_tuple(new_size) + tuple(old_size)
+
+    components = Mixture._resize_components(new_size, *components)
+
+    return Mixture.rv_op(weights, *components, rngs=rngs, size=None)
+
+
 @_get_measurable_outputs.register(MarginalMixtureRV)
@@ -378,7 +380,7 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
     # single component
     if len(components) == 1:
         # Need to broadcast value across mixture axis
-        mix_axis = -components[0].owner.op.ndim_supp - 1
+        mix_axis = -ndim_supp_dist(components[0]) - 1
         components_logp = logp(components[0], at.expand_dims(value, mix_axis))
     else:
         components_logp = at.stack(
@@ -411,7 +413,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
     # single component
     if len(components) == 1:
         # Need to broadcast value across mixture axis
-        mix_axis = -components[0].owner.op.ndim_supp - 1
+        mix_axis = -ndim_supp_dist(components[0]) - 1
         components_logcdf = logcdf(components[0], at.expand_dims(value, mix_axis))
     else:
         components_logcdf = at.stack(
@@ -440,7 +442,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
 
 @_moment.register(MarginalMixtureRV)
 def marginal_mixture_moment(op, rv, rng, weights, *components):
-    ndim_supp = components[0].owner.op.ndim_supp
+    ndim_supp = ndim_supp_dist(components[0])
     weights = at.shape_padright(weights, ndim_supp)
     mix_axis = -ndim_supp - 1
 
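As with `Censored`, the `MarginalMixtureRV` registrations above remove the need for the tag references that `rv_op` previously attached, reading the weights and components straight from `rv.owner.inputs`. A hypothetical usage sketch, assuming this PR's branch:

```python
import pymc as pm

from pymc.distributions.shape_utils import ndim_supp_dist, resize_dist

components = [pm.Normal.dist(mu=-1.0), pm.Normal.dist(mu=1.0)]
mix = pm.Mixture.dist(w=[0.5, 0.5], comp_dists=components)

# Dispatches on MarginalMixtureRV and defers to the components
assert ndim_supp_dist(mix) == 0

# expand=True prepends the new batch shape while `_resize_components`
# keeps the mixture axis intact
expanded = resize_dist(mix, (5,), expand=True)
```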