diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index a483017b7c..955a800019 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -613,7 +613,7 @@ def dist( cls, *dist_params, class_name: str, - random: Callable, + dist: Callable, logp: Optional[Callable] = None, logcdf: Optional[Callable] = None, moment: Optional[Callable] = None, @@ -622,7 +622,7 @@ def dist( **kwargs, ): warnings.warn( - "CustomDist with symbolic random graph is still experimental. Expect bugs!", + "CustomDist with dist function is still experimental. Expect bugs!", UserWarning, ) @@ -644,7 +644,7 @@ def dist( class_name=class_name, logp=logp, logcdf=logcdf, - random=random, + dist=dist, moment=moment, ndim_supp=ndim_supp, **kwargs, @@ -655,7 +655,7 @@ def rv_op( cls, *dist_params, class_name: str, - random: Callable, + dist: Callable, logp: Optional[Callable], logcdf: Optional[Callable], moment: Optional[Callable], @@ -666,16 +666,16 @@ def rv_op( dummy_size_param = size.type() dummy_dist_params = [dist_param.type() for dist_param in dist_params] with BlockModelAccess( - error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API" + error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API" ): - dummy_rv = random(*dummy_dist_params, dummy_size_param) + dummy_rv = dist(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param] + dummy_dist_params dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) rv_type = type( f"CustomSymbolicDistRV_{class_name}", (CustomSymbolicDistRV,), - # If logp is not provided, we infer it from the random graph + # If logp is not provided, we try to infer it from the dist graph dict( inline_logprob=logp is None, ), @@ -697,11 +697,11 @@ def custom_dist_get_moment(op, rv, size, *params): return moment(rv, size, *params[: len(params)]) @_change_dist_size.register(rv_type) - def change_custom_symbolic_dist_size(op, dist, new_size, expand): - node = dist.owner + def change_custom_symbolic_dist_size(op, rv, new_size, expand): + node = rv.owner if expand: - shape = tuple(dist.shape) + shape = tuple(rv.shape) old_size = shape[: len(shape) - node.op.ndim_supp] new_size = tuple(new_size) + tuple(old_size) new_size = at.as_tensor(new_size, ndim=1, dtype="int64") @@ -711,7 +711,7 @@ def change_custom_symbolic_dist_size(op, dist, new_size, expand): # OpFromGraph has to be recreated if the size type changes! dummy_size_param = new_size.type() dummy_dist_params = [dist_param.type() for dist_param in old_dist_params] - dummy_rv = random(*dummy_dist_params, dummy_size_param) + dummy_rv = dist(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param] + dummy_dist_params dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) new_rv_op = rv_type( @@ -737,17 +737,18 @@ class CustomDist: This class can be used to wrap black-box random and logp methods for use in forward and mcmc sampling. - A user can provide a `random` function that returns numerical draws (e.g., via - NumPy routines) or an Aesara graph that represents the random graph when evaluated. + A user can provide a `dist` function that returns a PyTensor graph built from + simpler PyMC distributions, which represents the distribution. This graph is + used to take random draws, and to infer the logp expression automatically + when not provided by the user. - A user can provide a `logp` function that must return an Aesara graph that - represents the logp graph when evaluated. This is used for mcmc sampling. In some - cases, if a user provides a `random` function that returns an Aesara graph, PyMC - will be able to automatically derive the appropriate `logp` graph when performing - MCMC sampling. + Alternatively, a user can provide a `random` function that returns numerical + draws (e.g., via NumPy routines), and a `logp` function that must return an + Python graph that represents the logp graph when evaluated. This is used for + mcmc sampling. Additionally, a user can provide a `logcdf` and `moment` functions that must return - an Aesara graph that computes those quantities. These may be used by other PyMC + an PyTensor graph that computes those quantities. These may be used by other PyMC routines. Parameters @@ -765,11 +766,20 @@ class CustomDist: different methods across separate models, be sure to use distinct class_names. + dist: Optional[Callable] + A callable that returns a PyTensor graph built from simpler PyMC distributions + which represents the distribution. This can be used by PyMC to take random draws + as well as to infer the logp of the distribution in some cases. In that case + it's not necessary to implement ``random`` or ``logp`` functions. + + It must have the following signature: ``dist(*dist_params, size)``. + The symbolic tensor distribution parameters are passed as positional arguments in + the same order as they are supplied when the ``CustomDist`` is constructed. + random : Optional[Callable] - A callable that can be used to 1) generate random draws from the distribution or - 2) returns an Aesara graph that represents the random draws. + A callable that can be used to generate random draws from the distribution - If 1) it must have the following signature: ``random(*dist_params, rng=None, size=None)``. + It must have the following signature: ``random(*dist_params, rng=None, size=None)``. The numerical distribution parameters are passed as positional arguments in the same order as they are supplied when the ``CustomDist`` is constructed. The keyword arguments are ``rng``, which will provide the random variable's @@ -778,9 +788,6 @@ class CustomDist: error will be raised when trying to draw random samples from the distribution's prior or posterior predictive. - If 2) it must have the following signature: ``random(*dist_params, size)``. - The symbolic tensor distribution parameters are passed as postional arguments in - the same order as they are supplied when the ``CustomDist`` is constructed. logp : Optional[Callable] A callable that calculates the log probability of some given ``value`` conditioned on certain distribution parameter values. It must have the @@ -789,8 +796,8 @@ class CustomDist: are the tensors that hold the values of the distribution parameters. This function must return an PyTensor tensor. - When the `random` function is specified and returns an `Aesara` graph, PyMC - will try to automatically infer the `logp` when this is not provided. + When the `dist` function is specified, PyMC will try to automatically + infer the `logp` when this is not provided. Otherwise, a ``NotImplementedError`` will be raised when trying to compute the distribution's logp. @@ -818,11 +825,11 @@ class CustomDist: The list of number of dimensions in the support of each of the distribution's parameters. If ``None``, it is assumed that all parameters are scalars, hence the number of dimensions of their support will be 0. This is not needed if an - Aesara random function is provided + PyTensor dist function is provided. dtype : str The dtype of the distribution. All draws and observations passed into the - distribution will be cast onto this dtype. This is not needed if an Aesara - random function is provided, which should already return the right dtype! + distribution will be cast onto this dtype. This is not needed if an PyTensor + dist function is provided, which should already return the right dtype! kwargs : Extra keyword arguments are passed to the parent's class ``__new__`` method. @@ -884,16 +891,16 @@ def random( ) prior = pm.sample_prior_predictive(10) - Provide a random function that creates an Aesara random graph. PyMC can - automatically infer that the logp of this variable corresponds to a shifted - Exponential distribution. + Provide a dist function that creates a PyTensor graph built from other + PyMC distributions. PyMC can automatically infer that the logp of this + variable corresponds to a shifted Exponential distribution. .. code-block:: python import pymc as pm from pytensor.tensor import TensorVariable - def random( + def dist( lam: TensorVariable, shift: TensorVariable, size: TensorVariable, @@ -907,16 +914,16 @@ def random( "custom_dist", lam, shift, - random=random, + dist=dist, observed=[-1, -1, 0], ) prior = pm.sample_prior_predictive() posterior = pm.sample() - Provide a random function that creates an Aesara random graph. PyMC can - automatically infer that the logp of this variable corresponds to a - modified-PERT distribution. + Provide a dist function that creates a PyTensor graph built from other + PyMC distributions. PyMC can automatically infer that the logp of + this variable corresponds to a modified-PERT distribution. .. code-block:: python @@ -940,7 +947,7 @@ def pert( peak = pm.Normal("peak", 50, 10) high = pm.Normal("high", 100, 10) lmbda = 4 - pm.CustomDist("pert", low, peak, high, lmbda, random=pert, observed=[30, 35, 73]) + pm.CustomDist("pert", low, peak, high, lmbda, dist=pert, observed=[30, 35, 73]) m.point_logps() @@ -950,6 +957,7 @@ def __new__( cls, name, *dist_params, + dist: Optional[Callable] = None, random: Optional[Callable] = None, logp: Optional[Callable] = None, logcdf: Optional[Callable] = None, @@ -968,12 +976,13 @@ def __new__( "parameters are positional arguments." ) dist_params = cls.parse_dist_params(dist_params) - if cls.is_symbolic_random(random, dist_params): + cls.check_valid_dist_random(dist, random, dist_params) + if dist is not None: return _CustomSymbolicDist( name, *dist_params, class_name=name, - random=random, + dist=dist, logp=logp, logcdf=logcdf, moment=moment, @@ -1001,6 +1010,7 @@ def dist( cls, *dist_params, class_name: str, + dist: Optional[Callable] = None, random: Optional[Callable] = None, logp: Optional[Callable] = None, logcdf: Optional[Callable] = None, @@ -1011,11 +1021,12 @@ def dist( **kwargs, ): dist_params = cls.parse_dist_params(dist_params) - if cls.is_symbolic_random(random, dist_params): + cls.check_valid_dist_random(dist, random, dist_params) + if dist is not None: return _CustomSymbolicDist.dist( *dist_params, class_name=class_name, - random=random, + dist=dist, logp=logp, logcdf=logcdf, moment=moment, @@ -1048,6 +1059,16 @@ def parse_dist_params(cls, dist_params): ) return [as_tensor_variable(param) for param in dist_params] + @classmethod + def check_valid_dist_random(cls, dist, random, dist_params): + if dist is not None and random is not None: + raise ValueError("Cannot provide both dist and random functions") + if random is not None and cls.is_symbolic_random(random, dist_params): + raise TypeError( + "API change: function passed to `random` argument should no longer return a PyTensor graph. " + "Pass such function to the `dist` argument instead." + ) + @classmethod def is_symbolic_random(self, random, dist_params): if random is None: @@ -1056,7 +1077,7 @@ def is_symbolic_random(self, random, dist_params): try: size = normalize_size_param(None) with BlockModelAccess( - error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API" + error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables." ): out = random(*dist_params, size) except BlockModelAccessError: diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index d4d441096d..e353bb9ff6 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -368,7 +368,7 @@ def test_dist(self): class TestCustomSymbolicDist: def test_basic(self): - def custom_random(mu, sigma, size): + def custom_dist(mu, sigma, size): return at.exp(pm.Normal.dist(mu, sigma, size=size)) with Model() as m: @@ -379,7 +379,7 @@ def custom_random(mu, sigma, size): "lognormal", mu, sigma, - random=custom_random, + dist=custom_dist, size=(10,), transform=log, initval=np.ones(10), @@ -401,7 +401,7 @@ def custom_random(mu, sigma, size): np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip)) def test_random_multiple_rngs(self): - def custom_random(p, sigma, size): + def custom_dist(p, sigma, size): idx = pm.Bernoulli.dist(p=p) comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T return comps[idx] @@ -411,7 +411,7 @@ def custom_random(p, sigma, size): 0.5, 10.0, class_name="customdist", - random=custom_random, + dist=custom_dist, size=(10,), ) @@ -426,7 +426,7 @@ def custom_random(p, sigma, size): assert np.unique(draws).size == 20 def test_custom_methods(self): - def custom_random(mu, size): + def custom_dist(mu, size): if rv_size_is_none(size): return mu return at.full(size, mu) @@ -444,7 +444,7 @@ def custom_logcdf(value, mu): customdist = CustomDist.dist( [np.e, np.e], class_name="customdist", - random=custom_random, + dist=custom_dist, moment=custom_moment, logp=custom_logp, logcdf=custom_logcdf, @@ -458,7 +458,7 @@ def custom_logcdf(value, mu): np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3]) def test_change_size(self): - def custom_random(mu, sigma, size): + def custom_dist(mu, sigma, size): return at.exp(pm.Normal.dist(mu, sigma, size=size)) with pytest.warns(UserWarning, match="experimental"): @@ -466,7 +466,7 @@ def custom_random(mu, sigma, size): 0, 1, class_name="lognormal", - random=custom_random, + dist=custom_dist, size=(10,), ) assert isinstance(lognormal.owner.op, CustomSymbolicDistRV) @@ -481,15 +481,26 @@ def custom_random(mu, sigma, size): assert tuple(new_lognormal.shape.eval()) == (2, 5, 10) def test_error_model_access(self): - def random(size): + def custom_dist(size): return pm.Flat("Flat", size=size) with pm.Model() as m: with pytest.raises( BlockModelAccessError, - match="Model variables cannot be created in the random function", + match="Model variables cannot be created in the dist function", ): - CustomDist("custom_dist", random=random) + CustomDist("custom_dist", dist=custom_dist) + + def test_api_change_error(self): + def old_random(size): + return pm.Flat.dist(size=size) + + # Old API raises + with pytest.raises(TypeError, match="API change: function passed to `random` argument"): + pm.CustomDist.dist(random=old_random, class_name="custom_dist") + + # New API is fine + pm.CustomDist.dist(dist=old_random, class_name="custom_dist") class TestSymbolicRandomVariable: