Skip to content

Fix broadcasting via observed and dims #6063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 72 additions & 78 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Optional, Sequence, Tuple, Union, cast
from typing import Callable, Optional, Sequence, Tuple, Union

import aesara
import numpy as np
Expand All @@ -33,19 +33,18 @@
from aesara.tensor.var import TensorVariable
from typing_extensions import TypeAlias

from pymc.aesaraf import change_rv_size
from pymc.aesaraf import change_rv_size, convert_observed_data
from pymc.distributions.shape_utils import (
Dims,
Shape,
Size,
StrongDims,
StrongShape,
WeakDims,
convert_dims,
convert_shape,
convert_size,
find_size,
resize_from_dims,
resize_from_observed,
shape_from_dims,
)
from pymc.printing import str_for_dist, str_for_symbolic_dist
from pymc.util import UNSET
Expand Down Expand Up @@ -150,35 +149,33 @@ def fn(*args, **kwargs):
return fn


def _make_rv_and_resize_shape(
def _make_rv_and_resize_shape_from_dims(
*,
cls,
dims: Optional[Dims],
dims: Optional[StrongDims],
model,
observed,
args,
**kwargs,
) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]:
"""Creates the RV and processes dims or observed to determine a resize shape."""
# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
) -> Tuple[Variable, StrongShape]:
"""Creates the RV, possibly using dims or observed to determine a resize shape (if needed)."""
resize_shape_from_dims = None
size_or_shape = kwargs.get("size") or kwargs.get("shape")

# Preference is given to size or shape. If not specified, we rely on dims and
# finally, observed, to determine the shape of the variable. Because dims can be
# specified on the fly, we need a two-step process where we first create the RV
# without dims information and then resize it.
if not size_or_shape and observed is not None:
kwargs["shape"] = tuple(observed.shape)

# Create the RV without dims information
rv_out = cls.dist(*args, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# # `dims` are only available with this API, because `.dist()` can be used
# # without a modelcontext and dims are not tracked at the Aesara level.
dims = convert_dims(dims)
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
if dims is not None:
if dims_can_resize:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif Ellipsis in dims:
# Replace ... with None entries to match the actual dimensionality.
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
elif observed is not None:
resize_shape, observed = resize_from_observed(observed, ndim_actual)
return rv_out, dims, observed, resize_shape

if not size_or_shape and dims is not None:
resize_shape_from_dims = shape_from_dims(dims, tuple(rv_out.shape), model)

return rv_out, resize_shape_from_dims


class Distribution(metaclass=DistributionMeta):
Expand Down Expand Up @@ -212,14 +209,17 @@ def __new__(
rng : optional
Random number generator to use with the RandomVariable.
dims : tuple, optional
A tuple of dimension names known to the model.
A tuple of dimension names known to the model. When shape is not provided,
the shape of dims is used to define the shape of the variable.
initval : optional
Numeric or symbolic untransformed initial value of matching shape,
or one of the following initial value strategies: "moment", "prior".
Depending on the sampler's settings, a random jitter may be added to numeric, symbolic
or moment-based initial values in the transformed space.
observed : optional
Observed data to be passed when registering the random variable in the model.
When neither shape nor dims is provided, the shape of observed is used to
define the shape of the variable.
See ``Model.register_rv``.
total_size : float, optional
See ``Model.register_rv``.
Expand Down Expand Up @@ -258,15 +258,21 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
dims = convert_dims(dims)
if observed is not None:
observed = convert_observed_data(observed)

# Create the RV, without taking `dims` into consideration
rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
# Resize variable based on `dims` information
if resize_shape_from_dims:
resize_size_from_dims = find_size(
shape=resize_shape_from_dims, size=None, ndim_supp=cls.rv_op.ndim_supp
)
rv_out = change_rv_size(rv=rv_out, new_size=resize_size_from_dims, expand=False)

rv_out = model.register_rv(
rv_out,
Expand All @@ -286,7 +292,7 @@ def __new__(

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
return rv_out

@classmethod
Expand All @@ -305,9 +311,6 @@ def dist(
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.

An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
**kwargs
Keyword arguments that will be forwarded to the Aesara RV Op.
Most prominently: ``size`` or ``dtype``.
Expand Down Expand Up @@ -343,21 +346,12 @@ def dist(
shape = convert_shape(shape)
size = convert_size(size)

create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
)
# Create the RV with a `size` right away.
# This is not necessarily the final result.
create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp)
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
if shape is not None and Ellipsis in shape:
replicate_shape = cast(StrongShape, shape[:-1])
rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
return rv_out


Expand Down Expand Up @@ -414,14 +408,17 @@ def __new__(
name : str
Name for the new model variable.
dims : tuple, optional
A tuple of dimension names known to the model.
A tuple of dimension names known to the model. When shape is not provided,
the shape of dims is used to define the shape of the variable.
initval : optional
Numeric or symbolic untransformed initial value of matching shape,
or one of the following initial value strategies: "moment", "prior".
Depending on the sampler's settings, a random jitter may be added to numeric,
symbolic or moment-based initial values in the transformed space.
observed : optional
Observed data to be passed when registering the random variable in the model.
When neither shape nor dims is provided, the shape of observed is used to
define the shape of the variable.
See ``Model.register_rv``.
total_size : float, optional
See ``Model.register_rv``.
Expand Down Expand Up @@ -460,19 +457,21 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
dims = convert_dims(dims)
if observed is not None:
observed = convert_observed_data(observed)

# Create the RV, without taking `dims` into consideration
rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims(
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = cls.change_size(
rv=rv_out,
new_size=resize_shape,
expand=True,
# Resize variable based on `dims` information
if resize_shape_from_dims:
resize_size_from_dims = find_size(
shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.tag.ndim_supp
)
rv_out = cls.change_size(rv=rv_out, new_size=resize_size_from_dims, expand=False)

rv_out = model.register_rv(
rv_out,
Expand All @@ -489,6 +488,10 @@ def __new__(
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")

return rv_out

@classmethod
Expand All @@ -508,8 +511,6 @@ def dist(
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.
An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.

Expand Down Expand Up @@ -543,23 +544,16 @@ def dist(
shape = convert_shape(shape)
size = convert_size(size)

create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
)
# Create the RV with a `size` right away.
# This is not necessarily the final result.
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)

# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
if shape is not None and Ellipsis in shape:
replicate_shape = cast(StrongShape, shape[:-1])
graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)

# TODO: Create new attr error stating that these are not available for DerivedDistribution
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
# rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
return graph
ndim_supp = cls.ndim_supp(*dist_params)
create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
# This is needed for resizing from dims in `__new__`
rv_out.tag.ndim_supp = ndim_supp

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
return rv_out


@singledispatch
Expand Down
Loading