Skip to content

Use separate argument in CustomDist for functions that return symbolic representations #6462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 65 additions & 44 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def dist(
cls,
*dist_params,
class_name: str,
random: Callable,
dist: Callable,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
Expand All @@ -622,7 +622,7 @@ def dist(
**kwargs,
):
warnings.warn(
"CustomDist with symbolic random graph is still experimental. Expect bugs!",
"CustomDist with dist function is still experimental. Expect bugs!",
UserWarning,
)

Expand All @@ -644,7 +644,7 @@ def dist(
class_name=class_name,
logp=logp,
logcdf=logcdf,
random=random,
dist=dist,
moment=moment,
ndim_supp=ndim_supp,
**kwargs,
Expand All @@ -655,7 +655,7 @@ def rv_op(
cls,
*dist_params,
class_name: str,
random: Callable,
dist: Callable,
logp: Optional[Callable],
logcdf: Optional[Callable],
moment: Optional[Callable],
Expand All @@ -666,16 +666,16 @@ def rv_op(
dummy_size_param = size.type()
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
with BlockModelAccess(
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
):
dummy_rv = random(*dummy_dist_params, dummy_size_param)
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))

rv_type = type(
f"CustomSymbolicDistRV_{class_name}",
(CustomSymbolicDistRV,),
# If logp is not provided, we infer it from the random graph
# If logp is not provided, we try to infer it from the dist graph
dict(
inline_logprob=logp is None,
),
Expand All @@ -697,11 +697,11 @@ def custom_dist_get_moment(op, rv, size, *params):
return moment(rv, size, *params[: len(params)])

@_change_dist_size.register(rv_type)
def change_custom_symbolic_dist_size(op, dist, new_size, expand):
node = dist.owner
def change_custom_symbolic_dist_size(op, rv, new_size, expand):
node = rv.owner

if expand:
shape = tuple(dist.shape)
shape = tuple(rv.shape)
old_size = shape[: len(shape) - node.op.ndim_supp]
new_size = tuple(new_size) + tuple(old_size)
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
Expand All @@ -711,7 +711,7 @@ def change_custom_symbolic_dist_size(op, dist, new_size, expand):
# OpFromGraph has to be recreated if the size type changes!
dummy_size_param = new_size.type()
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
dummy_rv = random(*dummy_dist_params, dummy_size_param)
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
new_rv_op = rv_type(
Expand All @@ -737,17 +737,18 @@ class CustomDist:
This class can be used to wrap black-box random and logp methods for use in
forward and mcmc sampling.

A user can provide a `random` function that returns numerical draws (e.g., via
NumPy routines) or an Aesara graph that represents the random graph when evaluated.
A user can provide a `dist` function that returns a PyTensor graph built from
simpler PyMC distributions, which represents the distribution. This graph is
used to take random draws, and to infer the logp expression automatically
when not provided by the user.

A user can provide a `logp` function that must return an Aesara graph that
represents the logp graph when evaluated. This is used for mcmc sampling. In some
cases, if a user provides a `random` function that returns an Aesara graph, PyMC
will be able to automatically derive the appropriate `logp` graph when performing
MCMC sampling.
Alternatively, a user can provide a `random` function that returns numerical
draws (e.g., via NumPy routines), and a `logp` function that must return a
PyTensor graph that represents the logp graph when evaluated. This is used for
mcmc sampling.

Additionally, a user can provide a `logcdf` and `moment` functions that must return
an Aesara graph that computes those quantities. These may be used by other PyMC
a PyTensor graph that computes those quantities. These may be used by other PyMC
routines.

Parameters
Expand All @@ -765,11 +766,20 @@ class CustomDist:
different methods across separate models, be sure to use distinct
class_names.

dist: Optional[Callable]
A callable that returns a PyTensor graph built from simpler PyMC distributions
which represents the distribution. This can be used by PyMC to take random draws
as well as to infer the logp of the distribution in some cases. In that case
it's not necessary to implement ``random`` or ``logp`` functions.

It must have the following signature: ``dist(*dist_params, size)``.
The symbolic tensor distribution parameters are passed as positional arguments in
the same order as they are supplied when the ``CustomDist`` is constructed.

random : Optional[Callable]
A callable that can be used to 1) generate random draws from the distribution or
2) returns an Aesara graph that represents the random draws.
A callable that can be used to generate random draws from the distribution

If 1) it must have the following signature: ``random(*dist_params, rng=None, size=None)``.
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
The numerical distribution parameters are passed as positional arguments in the
same order as they are supplied when the ``CustomDist`` is constructed.
The keyword arguments are ``rng``, which will provide the random variable's
Expand All @@ -778,9 +788,6 @@ class CustomDist:
error will be raised when trying to draw random samples from the distribution's
prior or posterior predictive.

If 2) it must have the following signature: ``random(*dist_params, size)``.
The symbolic tensor distribution parameters are passed as postional arguments in
the same order as they are supplied when the ``CustomDist`` is constructed.
logp : Optional[Callable]
A callable that calculates the log probability of some given ``value``
conditioned on certain distribution parameter values. It must have the
Expand All @@ -789,8 +796,8 @@ class CustomDist:
are the tensors that hold the values of the distribution parameters.
This function must return an PyTensor tensor.

When the `random` function is specified and returns an `Aesara` graph, PyMC
will try to automatically infer the `logp` when this is not provided.
When the `dist` function is specified, PyMC will try to automatically
infer the `logp` when this is not provided.

Otherwise, a ``NotImplementedError`` will be raised when trying to compute the
distribution's logp.
Expand Down Expand Up @@ -818,11 +825,11 @@ class CustomDist:
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0. This is not needed if an
Aesara random function is provided
PyTensor dist function is provided.
dtype : str
The dtype of the distribution. All draws and observations passed into the
distribution will be cast onto this dtype. This is not needed if an Aesara
random function is provided, which should already return the right dtype!
distribution will be cast onto this dtype. This is not needed if a PyTensor
dist function is provided, which should already return the right dtype!
kwargs :
Extra keyword arguments are passed to the parent's class ``__new__`` method.

Expand Down Expand Up @@ -884,16 +891,16 @@ def random(
)
prior = pm.sample_prior_predictive(10)

Provide a random function that creates an Aesara random graph. PyMC can
automatically infer that the logp of this variable corresponds to a shifted
Exponential distribution.
Provide a dist function that creates a PyTensor graph built from other
PyMC distributions. PyMC can automatically infer that the logp of this
variable corresponds to a shifted Exponential distribution.

.. code-block:: python

import pymc as pm
from pytensor.tensor import TensorVariable

def random(
def dist(
lam: TensorVariable,
shift: TensorVariable,
size: TensorVariable,
Expand All @@ -907,16 +914,16 @@ def random(
"custom_dist",
lam,
shift,
random=random,
dist=dist,
observed=[-1, -1, 0],
)

prior = pm.sample_prior_predictive()
posterior = pm.sample()

Provide a random function that creates an Aesara random graph. PyMC can
automatically infer that the logp of this variable corresponds to a
modified-PERT distribution.
Provide a dist function that creates a PyTensor graph built from other
PyMC distributions. PyMC can automatically infer that the logp of
this variable corresponds to a modified-PERT distribution.

.. code-block:: python

Expand All @@ -940,7 +947,7 @@ def pert(
peak = pm.Normal("peak", 50, 10)
high = pm.Normal("high", 100, 10)
lmbda = 4
pm.CustomDist("pert", low, peak, high, lmbda, random=pert, observed=[30, 35, 73])
pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73])

m.point_logps()

Expand All @@ -950,6 +957,7 @@ def __new__(
cls,
name,
*dist_params,
dist: Optional[Callable] = None,
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
Expand All @@ -968,12 +976,13 @@ def __new__(
"parameters are positional arguments."
)
dist_params = cls.parse_dist_params(dist_params)
if cls.is_symbolic_random(random, dist_params):
cls.check_valid_dist_random(dist, random, dist_params)
if dist is not None:
return _CustomSymbolicDist(
name,
*dist_params,
class_name=name,
random=random,
dist=dist,
logp=logp,
logcdf=logcdf,
moment=moment,
Expand Down Expand Up @@ -1001,6 +1010,7 @@ def dist(
cls,
*dist_params,
class_name: str,
dist: Optional[Callable] = None,
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
Expand All @@ -1011,11 +1021,12 @@ def dist(
**kwargs,
):
dist_params = cls.parse_dist_params(dist_params)
if cls.is_symbolic_random(random, dist_params):
cls.check_valid_dist_random(dist, random, dist_params)
if dist is not None:
return _CustomSymbolicDist.dist(
*dist_params,
class_name=class_name,
random=random,
dist=dist,
logp=logp,
logcdf=logcdf,
moment=moment,
Expand Down Expand Up @@ -1048,6 +1059,16 @@ def parse_dist_params(cls, dist_params):
)
return [as_tensor_variable(param) for param in dist_params]

@classmethod
def check_valid_dist_random(cls, dist, random, dist_params):
    """Validate the combination of ``dist`` and ``random`` arguments.

    Raises
    ------
    ValueError
        If the caller supplied both ``dist`` and ``random``.
    TypeError
        If ``random`` is a function that returns a symbolic (PyTensor)
        graph — the pre-#6462 API — which must now be passed as ``dist``.
    """
    if not (dist is None or random is None):
        raise ValueError("Cannot provide both dist and random functions")
    if random is None:
        # Nothing left to validate: only `dist` (or neither) was given.
        return
    if cls.is_symbolic_random(random, dist_params):
        raise TypeError(
            "API change: function passed to `random` argument should no longer return a PyTensor graph. "
            "Pass such function to the `dist` argument instead."
        )

@classmethod
def is_symbolic_random(self, random, dist_params):
if random is None:
Expand All @@ -1056,7 +1077,7 @@ def is_symbolic_random(self, random, dist_params):
try:
size = normalize_size_param(None)
with BlockModelAccess(
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
):
out = random(*dist_params, size)
except BlockModelAccessError:
Expand Down
33 changes: 22 additions & 11 deletions pymc/tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_dist(self):

class TestCustomSymbolicDist:
def test_basic(self):
def custom_random(mu, sigma, size):
def custom_dist(mu, sigma, size):
return at.exp(pm.Normal.dist(mu, sigma, size=size))

with Model() as m:
Expand All @@ -379,7 +379,7 @@ def custom_random(mu, sigma, size):
"lognormal",
mu,
sigma,
random=custom_random,
dist=custom_dist,
size=(10,),
transform=log,
initval=np.ones(10),
Expand All @@ -401,7 +401,7 @@ def custom_random(mu, sigma, size):
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))

def test_random_multiple_rngs(self):
def custom_random(p, sigma, size):
def custom_dist(p, sigma, size):
idx = pm.Bernoulli.dist(p=p)
comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T
return comps[idx]
Expand All @@ -411,7 +411,7 @@ def custom_random(p, sigma, size):
0.5,
10.0,
class_name="customdist",
random=custom_random,
dist=custom_dist,
size=(10,),
)

Expand All @@ -426,7 +426,7 @@ def custom_random(p, sigma, size):
assert np.unique(draws).size == 20

def test_custom_methods(self):
def custom_random(mu, size):
def custom_dist(mu, size):
if rv_size_is_none(size):
return mu
return at.full(size, mu)
Expand All @@ -444,7 +444,7 @@ def custom_logcdf(value, mu):
customdist = CustomDist.dist(
[np.e, np.e],
class_name="customdist",
random=custom_random,
dist=custom_dist,
moment=custom_moment,
logp=custom_logp,
logcdf=custom_logcdf,
Expand All @@ -458,15 +458,15 @@ def custom_logcdf(value, mu):
np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3])

def test_change_size(self):
def custom_random(mu, sigma, size):
def custom_dist(mu, sigma, size):
return at.exp(pm.Normal.dist(mu, sigma, size=size))

with pytest.warns(UserWarning, match="experimental"):
lognormal = CustomDist.dist(
0,
1,
class_name="lognormal",
random=custom_random,
dist=custom_dist,
size=(10,),
)
assert isinstance(lognormal.owner.op, CustomSymbolicDistRV)
Expand All @@ -481,15 +481,26 @@ def custom_random(mu, sigma, size):
assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)

def test_error_model_access(self):
def random(size):
def custom_dist(size):
return pm.Flat("Flat", size=size)

with pm.Model() as m:
with pytest.raises(
BlockModelAccessError,
match="Model variables cannot be created in the random function",
match="Model variables cannot be created in the dist function",
):
CustomDist("custom_dist", random=random)
CustomDist("custom_dist", dist=custom_dist)

def test_api_change_error(self):
    """The old API (graph-returning `random`) must raise; `dist` must accept it."""

    def graph_returning_random(size):
        return pm.Flat.dist(size=size)

    # Passing a graph-returning function through `random` is rejected
    with pytest.raises(TypeError, match="API change: function passed to `random` argument"):
        pm.CustomDist.dist(random=graph_returning_random, class_name="custom_dist")

    # The very same function is accepted through the new `dist` argument
    pm.CustomDist.dist(dist=graph_returning_random, class_name="custom_dist")


class TestSymbolicRandomVariable:
Expand Down