diff --git a/docs/source/api/distributions/utilities.rst b/docs/source/api/distributions/utilities.rst index f313eb383c..b5d3da0727 100644 --- a/docs/source/api/distributions/utilities.rst +++ b/docs/source/api/distributions/utilities.rst @@ -9,5 +9,5 @@ Distribution utilities Distribution Discrete Continuous - DensityDist + CustomDist SymbolicRandomVariable diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index e70192b262..6e6319d8e4 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -76,6 +76,7 @@ ) from pymc.distributions.distribution import ( Continuous, + CustomDist, DensityDist, Discrete, Distribution, @@ -154,6 +155,7 @@ "OrderedLogistic", "OrderedProbit", "DensityDist", + "CustomDist", "Distribution", "SymbolicRandomVariable", "Continuous", diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index a60bbc0d01..c96316a8cd 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -26,24 +26,27 @@ from pytensor import tensor as at from pytensor.compile.builders import OpFromGraph from pytensor.graph import node_rewriter -from pytensor.graph.basic import Node, clone_replace +from pytensor.graph.basic import Node, Variable, clone_replace from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import MetaType from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType +from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.var import TensorVariable from typing_extensions import TypeAlias from pymc.distributions.shape_utils import ( Dims, Shape, + _change_dist_size, convert_dims, convert_shape, convert_size, find_size, shape_from_dims, ) +from pymc.exceptions import BlockModelAccessError from pymc.logprob.abstract import ( MeasurableVariable, _get_measurable_outputs, @@ -52,13 +55,14 @@ _logprob, ) from pymc.logprob.rewriting import logprob_rewrites_db +from pymc.model import BlockModelAccess from pymc.printing import str_for_dist -from pymc.pytensorf import convert_observed_data +from pymc.pytensorf import collect_default_updates, convert_observed_data from pymc.util import UNSET, _add_future_warning_tag from pymc.vartypes import string_types __all__ = [ - "DensityDistRV", + "CustomDist", "DensityDist", "Distribution", "Continuous", @@ -456,15 +460,15 @@ class Continuous(Distribution): """Base class for continuous distributions""" -class DensityDistRV(RandomVariable): +class CustomDistRV(RandomVariable): """ - Base class for DensityDistRV + Base class for CustomDistRV - This should be subclassed when defining custom DensityDist objects. + This should be subclassed when defining CustomDist objects. """ - name = "DensityDistRV" - _print_name = ("DensityDist", "\\operatorname{DensityDist}") + name = "CustomDistRV" + _print_name = ("CustomDist", "\\operatorname{CustomDist}") @classmethod def rng_fn(cls, rng, *args): @@ -473,132 +477,10 @@ def rng_fn(cls, rng, *args): return cls._random_fn(*args, rng=rng, size=size) -class DensityDist(Distribution): - """A distribution that can be used to wrap black-box log density functions. +class _CustomDist(Distribution): + """A distribution that returns a subclass of CustomDistRV""" - Creates a Distribution and registers the supplied log density function to be used - for inference. 
It is also possible to supply a `random` method in order to be able - to sample from the prior or posterior predictive distributions. - - - Parameters - ---------- - name : str - dist_params : Tuple - A sequence of the distribution's parameter. These will be converted into - PyTensor tensors internally. These parameters could be other ``TensorVariable`` - instances created from , optionally created via ``RandomVariable`` ``Op``s. - class_name : str - Name for the RandomVariable class which will wrap the DensityDist methods. - When not specified, it will be given the name of the variable. - - .. warning:: New DensityDists created with the same class_name will override the - methods dispatched onto the previous classes. If using DensityDists with - different methods across separate models, be sure to use distinct - class_names. - - logp : Optional[Callable] - A callable that calculates the log density of some given observed ``value`` - conditioned on certain distribution parameter values. It must have the - following signature: ``logp(value, *dist_params)``, where ``value`` is - an PyTensor tensor that represents the observed value, and ``dist_params`` - are the tensors that hold the values of the distribution parameters. - This function must return an PyTensor tensor. If ``None``, a ``NotImplemented`` - error will be raised when trying to compute the distribution's logp. - logcdf : Optional[Callable] - A callable that calculates the log cummulative probability of some given observed - ``value`` conditioned on certain distribution parameter values. It must have the - following signature: ``logcdf(value, *dist_params)``, where ``value`` is - an PyTensor tensor that represents the observed value, and ``dist_params`` - are the tensors that hold the values of the distribution parameters. - This function must return an PyTensor tensor. If ``None``, a ``NotImplemented`` - error will be raised when trying to compute the distribution's logcdf. - random : Optional[Callable] - A callable that can be used to generate random draws from the distribution. - It must have the following signature: ``random(*dist_params, rng=None, size=None)``. - The distribution parameters are passed as positional arguments in the - same order as they are supplied when the ``DensityDist`` is constructed. - The keyword arguments are ``rnd``, which will provide the random variable's - associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent - the desired size of the random draw. If ``None``, a ``NotImplemented`` - error will be raised when trying to draw random samples from the distribution's - prior or posterior predictive. - moment : Optional[Callable] - A callable that can be used to compute the moments of the distribution. - It must have the following signature: ``moment(rv, size, *rv_inputs)``. - The distribution's :class:`~pytensor.tensor.random.op.RandomVariable` is passed - as the first argument ``rv``. ``size`` is the random variable's size implied - by the ``dims``, ``size`` and parameters supplied to the distribution. Finally, - ``rv_inputs`` is the sequence of the distribution parameters, in the same order - as they were supplied when the DensityDist was created. If ``None``, a default - ``moment`` function will be assigned that will always return 0, or an array - of zeros. - ndim_supp : int - The number of dimensions in the support of the distribution. Defaults to assuming - a scalar distribution, i.e. ``ndim_supp = 0``. 
- ndims_params : Optional[Sequence[int]] - The list of number of dimensions in the support of each of the distribution's - parameters. If ``None``, it is assumed that all parameters are scalars, hence - the number of dimensions of their support will be 0. - dtype : str - The dtype of the distribution. All draws and observations passed into the distribution - will be casted onto this dtype. - kwargs : - Extra keyword arguments are passed to the parent's class ``__new__`` method. - - Examples - -------- - .. code-block:: python - - def logp(value, mu): - return -(value - mu)**2 - - with pm.Model(): - mu = pm.Normal('mu',0,1) - pm.DensityDist( - 'density_dist', - mu, - logp=logp, - observed=np.random.randn(100), - ) - idata = pm.sample(100) - - .. code-block:: python - - def logp(value, mu): - return -(value - mu)**2 - - def random(mu, rng=None, size=None): - return rng.normal(loc=mu, scale=1, size=size) - - with pm.Model(): - mu = pm.Normal('mu', 0 , 1) - dens = pm.DensityDist( - 'density_dist', - mu, - logp=logp, - random=random, - observed=np.random.randn(100, 3), - size=(100, 3), - ) - prior = pm.sample_prior_predictive(10).prior_predictive['density_dist'] - assert prior.shape == (1, 10, 100, 3) - - """ - - rv_type = DensityDistRV - - def __new__(cls, name, *args, **kwargs): - kwargs.setdefault("class_name", name) - if isinstance(kwargs.get("observed", None), dict): - raise TypeError( - "Since ``v4.0.0`` the ``observed`` parameter should be of type" - " ``pd.Series``, ``np.array``, or ``pm.Data``." - " Previous versions allowed passing distribution parameters as" - " a dictionary in ``observed``, in the current version these " - "parameters are positional arguments." - ) - return super().__new__(cls, name, *args, **kwargs) + rv_type = CustomDistRV @classmethod def dist( @@ -615,17 +497,7 @@ def dist( **kwargs, ): - if dist_params is None: - dist_params = [] - elif len(dist_params) > 0 and callable(dist_params[0]): - raise TypeError( - "The DensityDist API has changed, you are using the old API " - "where logp was the first positional argument. In the current API, " - "the logp is a keyword argument, amongst other changes. Please refer " - "to the API documentation for more information on how to use the " - "new DensityDist API." 
- ) - dist_params = [as_tensor_variable(param) for param in dist_params] + dist_params = [as_tensor_variable(param) for param in dist_params] # Assume scalar ndims_params if ndims_params is None: @@ -675,44 +547,533 @@ def rv_op( dtype: str, **kwargs, ): - rv_op = type( - f"DensityDist_{class_name}", - (DensityDistRV,), + rv_type = type( + f"CustomDistRV_{class_name}", + (CustomDistRV,), dict( - name=f"DensityDist_{class_name}", + name=f"CustomDist_{class_name}", inplace=False, ndim_supp=ndim_supp, ndims_params=ndims_params, dtype=dtype, - # Specifc to DensityDist + # Specific to CustomDist _random_fn=random, ), - )() - - # Register custom logp - rv_type = type(rv_op) + ) + # Dispatch custom methods @_logprob.register(rv_type) - def density_dist_logp(op, value_var_list, *dist_params, **kwargs): - _dist_params = dist_params[3:] - value_var = value_var_list[0] - return logp(value_var, *_dist_params) + def custom_dist_logp(op, values, rng, size, dtype, *dist_params, **kwargs): + return logp(values[0], *dist_params) @_logcdf.register(rv_type) - def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs): - value_var = rvs_to_values.get(var, var) - return logcdf(value_var, *dist_params, **kwargs) + def custom_dist_logcdf(op, value, rng, size, dtype, *dist_params, **kwargs): + return logcdf(value, *dist_params, **kwargs) @_moment.register(rv_type) def density_dist_get_moment(op, rv, rng, size, dtype, *dist_params): return moment(rv, size, *dist_params) + rv_op = rv_type() return rv_op(*dist_params, **kwargs) +class CustomSymbolicDistRV(SymbolicRandomVariable): + """ + Base class for CustomSymbolicDist + + This should be subclassed when defining CustomDist objects that have + symbolic random methods. + """ + + default_output = -1 + + _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}") + + def update(self, node: Node): + op = node.op + inner_updates = collect_default_updates(op.inner_inputs, op.inner_outputs) + + # Map inner updates to outer inputs/outputs + updates = {} + for rng, update in inner_updates.items(): + inp_idx = op.inner_inputs.index(rng) + out_idx = op.inner_outputs.index(update) + updates[node.inputs[inp_idx]] = node.outputs[out_idx] + return updates
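For orientation, here is a minimal sketch of what ``CustomSymbolicDistRV`` enables, written against this branch's API (the required ``class_name`` keyword, the experimental ``UserWarning``, and the ``pm.draw`` usage are taken from the code and tests in this patch):

.. code-block:: python

    import pymc as pm
    import pytensor.tensor as at

    def custom_random(mu, sigma, size):
        # A symbolic random graph: a log-normal built from pm.Normal.dist
        return at.exp(pm.Normal.dist(mu, sigma, size=size))

    # Emits a UserWarning: the symbolic random API is still experimental
    lognormal = pm.CustomDist.dist(
        0.0, 1.0, class_name="lognormal", random=custom_random, size=(10,)
    )

    node = lognormal.owner
    # update() maps the single inner RNG to its updated outer output
    assert len(node.op.update(node)) == 1
    assert pm.draw(lognormal, draws=2).shape == (2, 10)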
Expect bugs!", + UserWarning, + ) + + dist_params = [as_tensor_variable(param) for param in dist_params] + + if logcdf is None: + logcdf = default_not_implemented(class_name, "logcdf") + + if moment is None: + moment = functools.partial( + default_moment, + rv_name=class_name, + has_fallback=True, + ndim_supp=ndim_supp, + ) + + return super().dist( + dist_params, + class_name=class_name, + logp=logp, + logcdf=logcdf, + random=random, + moment=moment, + ndim_supp=ndim_supp, + **kwargs, + ) + + @classmethod + def rv_op( + cls, + *dist_params, + class_name: str, + random: Callable, + logp: Optional[Callable], + logcdf: Optional[Callable], + moment: Optional[Callable], + size=None, + ndim_supp: int, + ): + size = normalize_size_param(size) + dummy_size_param = size.type() + dummy_dist_params = [dist_param.type() for dist_param in dist_params] + with BlockModelAccess( + error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API" + ): + dummy_rv = random(*dummy_dist_params, dummy_size_param) + dummy_params = [dummy_size_param] + dummy_dist_params + dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) + + rv_type = type( + f"CustomSymbolicDistRV_{class_name}", + (CustomSymbolicDistRV,), + # If logp is not provided, we infer it from the random graph + dict( + inline_logprob=logp is None, + ), + ) + + # Dispatch custom methods + if logp is not None: + + @_logprob.register(rv_type) + def custom_dist_logp(op, values, size, *params, **kwargs): + return logp(values[0], *params[: len(dist_params)]) + + @_logcdf.register(rv_type) + def custom_dist_logcdf(op, value, size, *params, **kwargs): + return logcdf(value, *params[: len(dist_params)]) + + @_moment.register(rv_type) + def custom_dist_get_moment(op, rv, size, *params): + return moment(rv, size, *params[: len(params)]) + + @_change_dist_size.register(rv_type) + def change_custom_symbolic_dist_size(op, dist, new_size, expand): + node = dist.owner + + if expand: + shape = tuple(dist.shape) + old_size = shape[: len(shape) - node.op.ndim_supp] + new_size = tuple(new_size) + tuple(old_size) + new_size = at.as_tensor(new_size, ndim=1, dtype="int64") + + old_size, *old_dist_params = node.inputs[: len(dist_params) + 1] + + # OpFromGraph has to be recreated if the size type changes! + dummy_size_param = new_size.type() + dummy_dist_params = [dist_param.type() for dist_param in old_dist_params] + dummy_rv = random(*dummy_dist_params, dummy_size_param) + dummy_params = [dummy_size_param] + dummy_dist_params + dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) + new_rv_op = rv_type( + inputs=dummy_params, + outputs=[*dummy_updates_dict.values(), dummy_rv], + ndim_supp=ndim_supp, + ) + new_rv = new_rv_op(new_size, *dist_params) + + return new_rv + + rv_op = rv_type( + inputs=dummy_params, + outputs=[*dummy_updates_dict.values(), dummy_rv], + ndim_supp=ndim_supp, + ) + return rv_op(size, *dist_params) + + +class CustomDist: + """A helper class to create custom distributions + + This class can be used to wrap black-box random and logp methods for use in + forward and mcmc sampling. + + A user can provide a `random` function that returns numerical draws (e.g., via + NumPy routines) or an Aesara graph that represents the random graph when evaluated. + + A user can provide a `logp` function that must return an Aesara graph that + represents the logp graph when evaluated. This is used for mcmc sampling. 
+ + +class CustomDist: + """A helper class to create custom distributions + + This class can be used to wrap black-box random and logp methods for use in + forward and MCMC sampling. + + A user can provide a `random` function that returns numerical draws (e.g., via + NumPy routines) or a PyTensor graph that represents the random graph when evaluated. + + A user can provide a `logp` function that must return a PyTensor graph that + represents the logp graph when evaluated. This is used for MCMC sampling. In some + cases, if a user provides a `random` function that returns a PyTensor graph, PyMC + will be able to automatically derive the appropriate `logp` graph when performing + MCMC sampling. + + Additionally, a user can provide `logcdf` and `moment` functions that must return + a PyTensor graph that computes those quantities. These may be used by other PyMC + routines. + + Parameters + ---------- + name : str + dist_params : Tuple + A sequence of the distribution's parameters. These will be converted into + PyTensor tensor variables internally. + class_name : str + Name for the class which will wrap the CustomDist methods. When not specified, + it will be given the name of the model variable. + + .. warning:: New CustomDists created with the same class_name will override the + methods dispatched onto the previous classes. If using CustomDists with + different methods across separate models, be sure to use distinct + class_names. + + random : Optional[Callable] + A callable that can be used to 1) generate random draws from the distribution or + 2) return a PyTensor graph that represents the random draws. + + If 1) it must have the following signature: ``random(*dist_params, rng=None, size=None)``. + The numerical distribution parameters are passed as positional arguments in the + same order as they are supplied when the ``CustomDist`` is constructed. + The keyword arguments are ``rng``, which will provide the random variable's + associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent + the desired size of the random draw. If ``None``, a ``NotImplementedError`` + will be raised when trying to draw random samples from the distribution's + prior or posterior predictive. + + If 2) it must have the following signature: ``random(*dist_params, size)``. + The symbolic tensor distribution parameters are passed as positional arguments in + the same order as they are supplied when the ``CustomDist`` is constructed. + logp : Optional[Callable] + A callable that calculates the log probability of some given ``value`` + conditioned on certain distribution parameter values. It must have the + following signature: ``logp(value, *dist_params)``, where ``value`` is + a PyTensor tensor that represents the distribution value, and ``dist_params`` + are the tensors that hold the values of the distribution parameters. + This function must return a PyTensor tensor. + + When the `random` function is specified and returns a `PyTensor` graph, PyMC + will try to automatically infer the `logp` when this is not provided. + + Otherwise, a ``NotImplementedError`` will be raised when trying to compute the + distribution's logp. + logcdf : Optional[Callable] + A callable that calculates the log cumulative probability of some given + ``value`` conditioned on certain distribution parameter values. It must have the + following signature: ``logcdf(value, *dist_params)``, where ``value`` is + a PyTensor tensor that represents the distribution value, and ``dist_params`` + are the tensors that hold the values of the distribution parameters. + This function must return a PyTensor tensor. If ``None``, a ``NotImplementedError`` + will be raised when trying to compute the distribution's logcdf. + moment : Optional[Callable] + A callable that can be used to compute the moments of the distribution. + It must have the following signature: ``moment(rv, size, *rv_inputs)``. + The distribution's variable is passed as the first argument ``rv``.
``size`` + is the random variable's size implied by the ``dims``, ``size`` and parameters + supplied to the distribution. Finally, ``rv_inputs`` is the sequence of the + distribution parameters, in the same order as they were supplied when the + CustomDist was created. If ``None``, a default ``moment`` function will be + assigned that will always return 0, or an array of zeros. + ndim_supp : int + The number of dimensions in the support of the distribution. Defaults to assuming + a scalar distribution, i.e. ``ndim_supp = 0``. + ndims_params : Optional[Sequence[int]] + The list of number of dimensions in the support of each of the distribution's + parameters. If ``None``, it is assumed that all parameters are scalars, hence + the number of dimensions of their support will be 0. This is not needed if a + PyTensor random function is provided. + dtype : str + The dtype of the distribution. All draws and observations passed into the + distribution will be cast onto this dtype. This is not needed if a PyTensor + random function is provided, which should already return the right dtype! + kwargs : + Extra keyword arguments are passed to the parent's class ``__new__`` method. + + + Examples + -------- + + Create a CustomDist that wraps a black-box logp function. This variable cannot be + used in prior or posterior predictive sampling because no random function was provided. + + .. code-block:: python + + import numpy as np + import pymc as pm + from pytensor.tensor import TensorVariable + + def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: + return -(value - mu)**2 + + with pm.Model(): + mu = pm.Normal('mu', 0, 1) + pm.CustomDist( + 'custom_dist', + mu, + logp=logp, + observed=np.random.randn(100), + ) + idata = pm.sample(100) + + Provide a random function that returns numerical draws. This allows one to use a + CustomDist in prior and posterior predictive sampling. + + .. code-block:: python + + from typing import Optional, Tuple + + import numpy as np + import pymc as pm + from pytensor.tensor import TensorVariable + + def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: + return -(value - mu)**2 + + def random( + mu: np.ndarray | float, + rng: Optional[np.random.Generator] = None, + size: Optional[Tuple[int]] = None, + ) -> np.ndarray | float: + return rng.normal(loc=mu, scale=1, size=size) + + with pm.Model(): + mu = pm.Normal('mu', 0, 1) + pm.CustomDist( + 'custom_dist', + mu, + logp=logp, + random=random, + observed=np.random.randn(100, 3), + size=(100, 3), + ) + prior = pm.sample_prior_predictive(10) + + Provide a random function that creates a PyTensor random graph. PyMC can + automatically infer that the logp of this variable corresponds to a shifted + Exponential distribution. + + .. code-block:: python + + import pymc as pm + from pytensor.tensor import TensorVariable + + def random( + lam: TensorVariable, + shift: TensorVariable, + size: TensorVariable, + ) -> TensorVariable: + return pm.Exponential.dist(lam, size=size) + shift + + with pm.Model() as m: + lam = pm.HalfNormal("lam") + shift = -1 + pm.CustomDist( + "custom_dist", + lam, + shift, + random=random, + observed=[-1, -1, 0], + ) + + prior = pm.sample_prior_predictive() + posterior = pm.sample() + + Provide a random function that creates a PyTensor random graph. PyMC can + automatically infer that the logp of this variable corresponds to a + modified-PERT distribution.
+ + .. code-block:: python + + import pymc as pm + from pytensor.tensor import TensorVariable + + def pert( + low: TensorVariable, + peak: TensorVariable, + high: TensorVariable, + lmbda: TensorVariable, + size: TensorVariable, + ) -> TensorVariable: + range = (high - low) + s_alpha = 1 + lmbda * (peak - low) / range + s_beta = 1 + lmbda * (high - peak) / range + return pm.Beta.dist(s_alpha, s_beta, size=size) * range + low + + with pm.Model() as m: + low = pm.Normal("low", 0, 10) + peak = pm.Normal("peak", 50, 10) + high = pm.Normal("high", 100, 10) + lmbda = 4 + pm.CustomDist("pert", low, peak, high, lmbda, random=pert, observed=[30, 35, 73]) + + m.point_logps() + + """ + + def __new__( + cls, + name, + *dist_params, + random: Optional[Callable] = None, + logp: Optional[Callable] = None, + logcdf: Optional[Callable] = None, + moment: Optional[Callable] = None, + ndim_supp: int = 0, + ndims_params: Optional[Sequence[int]] = None, + dtype: str = "floatX", + **kwargs, + ): + if isinstance(kwargs.get("observed", None), dict): + raise TypeError( + "Since ``v4.0.0`` the ``observed`` parameter should be of type" + " ``pd.Series``, ``np.array``, or ``pm.Data``." + " Previous versions allowed passing distribution parameters as" + " a dictionary in ``observed``, in the current version these " + "parameters are positional arguments." + ) + dist_params = cls.parse_dist_params(dist_params) + if cls.is_symbolic_random(random, dist_params): + return _CustomSymbolicDist( + name, + *dist_params, + class_name=name, + random=random, + logp=logp, + logcdf=logcdf, + moment=moment, + ndim_supp=ndim_supp, + **kwargs, + ) + else: + return _CustomDist( + name, + *dist_params, + class_name=name, + random=random, + logp=logp, + logcdf=logcdf, + moment=moment, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + dtype=dtype, + **kwargs, + ) + + @classmethod + def dist( + cls, + *dist_params, + class_name: str, + random: Optional[Callable] = None, + logp: Optional[Callable] = None, + logcdf: Optional[Callable] = None, + moment: Optional[Callable] = None, + ndim_supp: int = 0, + ndims_params: Optional[Sequence[int]] = None, + dtype: str = "floatX", + **kwargs, + ): + dist_params = cls.parse_dist_params(dist_params) + if cls.is_symbolic_random(random, dist_params): + return _CustomSymbolicDist.dist( + *dist_params, + class_name=class_name, + random=random, + logp=logp, + logcdf=logcdf, + moment=moment, + ndim_supp=ndim_supp, + **kwargs, + ) + else: + return _CustomDist.dist( + *dist_params, + class_name=class_name, + random=random, + logp=logp, + logcdf=logcdf, + moment=moment, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + dtype=dtype, + **kwargs, + ) + + @classmethod + def parse_dist_params(cls, dist_params): + if len(dist_params) > 0 and callable(dist_params[0]): + raise TypeError( + "The DensityDist API has changed, you are using the old API " + "where logp was the first positional argument. In the current API, " + "the logp is a keyword argument, amongst other changes. Please refer " + "to the API documentation for more information on how to use the " + "new DensityDist API." + ) + return [as_tensor_variable(param) for param in dist_params] + + @classmethod + def is_symbolic_random(cls, random, dist_params): + if random is None: + return False + # Try calling random with symbolic inputs + try: + size = normalize_size_param(None) + with BlockModelAccess( + error_msg_on_access="Model variables cannot be created in the random function.
Use the `.dist` API" + ): + out = random(*dist_params, size) + except BlockModelAccessError: + raise + except Exception: + # If it fails we assume it was not + return False + # Confirm the output is symbolic + return isinstance(out, Variable) + + +DensityDist = CustomDist + + def default_not_implemented(rv_name, method_name): message = ( - f"Attempted to run {method_name} on the DensityDist '{rv_name}', " + f"Attempted to run {method_name} on the CustomDist '{rv_name}', " f"but this method had not been provided when the distribution was " f"constructed. Please re-build your model and provide a callable " f"to '{rv_name}'s {method_name} keyword argument.\n" diff --git a/pymc/exceptions.py b/pymc/exceptions.py index 5b4141f303..7a18167d5c 100644 --- a/pymc/exceptions.py +++ b/pymc/exceptions.py @@ -82,3 +82,7 @@ class TruncationError(RuntimeError): class NotConstantValueError(ValueError): pass + + +class BlockModelAccessError(RuntimeError): + pass diff --git a/pymc/model.py b/pymc/model.py index 8c9d85af2b..2bb703f144 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -53,7 +53,13 @@ from pymc.data import GenTensorVariable, Minibatch from pymc.distributions.logprob import _joint_logp from pymc.distributions.transforms import _default_transform -from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning +from pymc.exceptions import ( + BlockModelAccessError, + ImputationWarning, + SamplingError, + ShapeError, + ShapeWarning, +) from pymc.initial_point import make_initial_point_fn from pymc.pytensorf import ( PointFunc, @@ -195,6 +201,8 @@ def get_context(cls, error_if_none=True) -> Optional[T]: if error_if_none: raise TypeError(f"No {cls} on context stack") return None + if isinstance(candidate, BlockModelAccess): + raise BlockModelAccessError(candidate.error_msg_on_access) return candidate def get_contexts(cls) -> List[T]: @@ -1798,6 +1806,13 @@ def point_logps(self, point=None, round_vals=2): Model._context_class = Model +class BlockModelAccess(Model): + """This class can be used to prevent user access to Model contexts""" + + def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwargs): + self.error_msg_on_access = error_msg_on_access + + def set_data(new_data, model=None, *, coords=None): """Sets the value of one or more data container variables. Note that the shape is also dynamic, it is updated when the value is changed. See the examples below for two common diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 7cf6e1946b..7bd52c5045 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1041,6 +1041,46 @@ def reseed_rngs( rng.set_value(new_rng, borrow=True) +def collect_default_updates( + inputs: Sequence[Variable], outputs: Sequence[Variable] +) -> Dict[Variable, Variable]: + """Collect default update expression of RVs between inputs and outputs""" + + # Avoid circular import + from pymc.distributions.distribution import SymbolicRandomVariable + + rng_updates = {} + output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] + for random_var in ( + var + for var in vars_between(inputs, output_to_list) + if var.owner + and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable)) + and var not in inputs + ): + # All nodes in `vars_between(inputs, outputs)` have owners. 
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 7cf6e1946b..7bd52c5045 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1041,6 +1041,46 @@ def reseed_rngs( rng.set_value(new_rng, borrow=True) +def collect_default_updates( + inputs: Sequence[Variable], outputs: Sequence[Variable] +) -> Dict[Variable, Variable]: + """Collect default update expressions of RVs between inputs and outputs""" + + # Avoid circular import + from pymc.distributions.distribution import SymbolicRandomVariable + + rng_updates = {} + output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] + for random_var in ( + var + for var in vars_between(inputs, output_to_list) + if var.owner + and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable)) + and var not in inputs + ): + # All nodes in `vars_between(inputs, outputs)` have owners. + # But mypy doesn't know, so we just assert it: + assert random_var.owner.op is not None + if isinstance(random_var.owner.op, RandomVariable): + rng = random_var.owner.inputs[0] + if hasattr(rng, "default_update"): + update_map = {rng: rng.default_update} + else: + update_map = {rng: random_var.owner.outputs[0]} + else: + update_map = random_var.owner.op.update(random_var.owner) + # Check that we are not setting different update expressions for the same variables + for rng, update in update_map.items(): + if rng not in rng_updates: + rng_updates[rng] = update + # When a variable has multiple outputs, it will be called twice with the same + # update expression. We don't want to raise in that case, only if the update + # expression is different from the one already registered + elif rng_updates[rng] is not update: + raise ValueError(f"Multiple update expressions found for the variable {rng}") + return rng_updates + + def compile_pymc( inputs, outputs, @@ -1082,40 +1122,9 @@ this function is called within a model context and the model `check_bounds` flag is set to False. """ - # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable - # Create an update mapping of RandomVariable's RNG so that it is automatically # updated after every function call - rng_updates = {} - output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] - for random_var in ( - var - for var in vars_between(inputs, output_to_list) - if var.owner - and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable)) - and var not in inputs - ): - # All nodes in `vars_between(inputs, outputs)` have owners. - # But mypy doesn't know, so we just assert it: - assert random_var.owner.op is not None - if isinstance(random_var.owner.op, RandomVariable): - rng = random_var.owner.inputs[0] - if hasattr(rng, "default_update"): - update_map = {rng: rng.default_update} - else: - update_map = {rng: random_var.owner.outputs[0]} - else: - update_map = random_var.owner.op.update(random_var.owner) - # Check that we are not setting different update expressions for the same variables - for rng, update in update_map.items(): - if rng not in rng_updates: - rng_updates[rng] = update - # When a variable has multiple outputs, it will be called twice with the same - # update expression.
We don't want to raise in that case, only if the update - # expression in different from the one already registered - elif rng_updates[rng] is not update: - raise ValueError(f"Multiple update expressions found for the variable {rng}") + rng_updates = collect_default_updates(inputs, outputs) # We always reseed random variables as this provides RNGs with no chances of collision if rng_updates: diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index 75d560e878..ae7d625ac3 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -25,10 +25,30 @@ import pymc as pm -from pymc.distributions import DiracDelta, Flat, MvNormal, MvStudentT, logp -from pymc.distributions.distribution import SymbolicRandomVariable, _moment, moment -from pymc.distributions.shape_utils import change_dist_size, to_tuple -from pymc.logprob.abstract import get_measurable_outputs +from pymc.distributions import ( + DiracDelta, + Flat, + HalfNormal, + LogNormal, + MvNormal, + MvStudentT, + Normal, + logp, +) +from pymc.distributions.distribution import ( + CustomDist, + CustomDistRV, + CustomSymbolicDistRV, + SymbolicRandomVariable, + _moment, + moment, +) +from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple +from pymc.distributions.transforms import log +from pymc.exceptions import BlockModelAccessError +from pymc.logprob.abstract import get_measurable_outputs, logcdf +from pymc.model import Model +from pymc.sampling import draw, sample from pymc.tests.distributions.util import assert_moment_is_expected from pymc.util import _FutureWarningValidatingScratchpad @@ -104,7 +124,7 @@ def test_all_distributions_have_moments(): dist_module.Distribution, dist_module.Discrete, dist_module.Continuous, - dist_module.DensityDist, + dist_module.CustomDist, dist_module.simulator.Simulator, } @@ -134,20 +154,21 @@ def test_all_distributions_have_moments(): ) -class TestDensityDist: +class TestCustomDist: @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str) - def test_density_dist_with_random(self, size): - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - obs = pm.DensityDist( - "density_dist", + def test_custom_dist_with_random(self, size): + with Model() as model: + mu = Normal("mu", 0, 1) + obs = CustomDist( + "custom_dist", mu, random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size), observed=np.random.randn(100, *size), ) + assert isinstance(obs.owner.op, CustomDistRV) assert obs.eval().shape == (100,) + size - def test_density_dist_with_random_invalid_observed(self): + def test_custom_dist_with_random_invalid_observed(self): with pytest.raises( TypeError, match=( @@ -159,37 +180,38 @@ def test_density_dist_with_random_invalid_observed(self): ), ): size = (3,) - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - pm.DensityDist( - "density_dist", + with Model() as model: + mu = Normal("mu", 0, 1) + CustomDist( + "custom_dist", mu, random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size), observed={"values": np.random.randn(100, *size)}, ) - def test_density_dist_without_random(self): - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - pm.DensityDist( - "density_dist", + def test_custom_dist_without_random(self): + with Model() as model: + mu = Normal("mu", 0, 1) + custom_dist = CustomDist( + "custom_dist", mu, logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value), observed=np.random.randn(100), initval=0, 
) - idata = pm.sample(tune=50, draws=100, cores=1, step=pm.Metropolis()) + assert isinstance(custom_dist.owner.op, CustomDistRV) + idata = sample(tune=50, draws=100, cores=1, step=pm.Metropolis()) with pytest.raises(NotImplementedError): pm.sample_posterior_predictive(idata, model=model) @pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str) - def test_density_dist_with_random_multivariate(self, size): + def test_custom_dist_with_random_multivariate(self, size): supp_shape = 5 - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1, size=supp_shape) - obs = pm.DensityDist( - "density_dist", + with Model() as model: + mu = Normal("mu", 0, 1, size=supp_shape) + obs = CustomDist( + "custom_dist", mu, random=lambda mu, rng=None, size=None: rng.multivariate_normal( mean=mu, cov=np.eye(len(mu)), size=size @@ -199,44 +221,47 @@ def test_density_dist_with_random_multivariate(self, size): ndim_supp=1, ) + assert isinstance(obs.owner.op, CustomDistRV) assert obs.eval().shape == (100,) + size + (supp_shape,) - def test_serialize_density_dist(self): + def test_serialize_custom_dist(self): def func(x): return -2 * (x**2).sum() def random(rng, size): return rng.uniform(-2, 2, size=size) - with pm.Model(): - pm.Normal("x") - y = pm.DensityDist("y", logp=func, random=random) + with Model(): + Normal("x") + y = CustomDist("y", logp=func, random=random) + assert isinstance(y.owner.op, CustomDistRV) with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - pm.sample(draws=5, tune=1, mp_ctx="spawn") + sample(draws=5, tune=1, mp_ctx="spawn") import cloudpickle cloudpickle.loads(cloudpickle.dumps(y)) - def test_density_dist_old_api_error(self): - with pm.Model(): + def test_custom_dist_old_api_error(self): + with Model(): with pytest.raises( TypeError, match="The DensityDist API has changed, you are using the old API" ): - pm.DensityDist("a", lambda x: x) + CustomDist("a", lambda x: x) @pytest.mark.parametrize("size", [None, (), (2,)], ids=str) - def test_density_dist_multivariate_logp(self, size): + def test_custom_dist_multivariate_logp(self, size): supp_shape = 5 - with pm.Model() as model: + with Model() as model: def logp(value, mu): return pm.MvNormal.logp(value, mu, at.eye(mu.shape[0])) - mu = pm.Normal("mu", size=supp_shape) - a = pm.DensityDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size) + mu = Normal("mu", size=supp_shape) + a = CustomDist("a", mu, logp=logp, ndims_params=[1], ndim_supp=1, size=size) + assert isinstance(a.owner.op, CustomDistRV) mu_test_value = npr.normal(loc=0, scale=1, size=supp_shape).astype(pytensor.config.floatX) a_test_value = npr.normal( loc=mu_test_value, scale=1, size=to_tuple(size) + (supp_shape,) @@ -253,37 +278,38 @@ def logp(value, mu): ("custom_moment", (2, 5), np.full((2, 5), 5)), ], ) - def test_density_dist_default_moment_univariate(self, moment, size, expected): + def test_custom_dist_default_moment_univariate(self, moment, size, expected): if moment == "custom_moment": moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype) with pm.Model() as model: - pm.DensityDist("x", moment=moment, size=size) + x = CustomDist("x", moment=moment, size=size) + assert isinstance(x.owner.op, CustomDistRV) assert_moment_is_expected(model, expected, check_finite_logp=False) @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) - def test_density_dist_custom_moment_univariate(self, size): + def test_custom_dist_custom_moment_univariate(self, size): def density_moment(rv, size, mu): return 
(at.ones(size) * mu).astype(rv.dtype) mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(pytensor.config.floatX) - with pm.Model(): - mu = pm.Normal("mu") - a = pm.DensityDist("a", mu, moment=density_moment, size=size) + with Model(): + mu = Normal("mu") + a = CustomDist("a", mu, moment=density_moment, size=size) + assert isinstance(a.owner.op, CustomDistRV) evaled_moment = moment(a).eval({mu: mu_val}) assert evaled_moment.shape == to_tuple(size) assert np.all(evaled_moment == mu_val) @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) - def test_density_dist_custom_moment_multivariate(self, size): + def test_custom_dist_custom_moment_multivariate(self, size): def density_moment(rv, size, mu): return (at.ones(size)[..., None] * mu).astype(rv.dtype) mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX) - with pm.Model(): - mu = pm.Normal("mu", size=5) - a = pm.DensityDist( - "a", mu, moment=density_moment, ndims_params=[1], ndim_supp=1, size=size - ) + with Model(): + mu = Normal("mu", size=5) + a = CustomDist("a", mu, moment=density_moment, ndims_params=[1], ndim_supp=1, size=size) + assert isinstance(a.owner.op, CustomDistRV) evaled_moment = moment(a).eval({mu: mu_val}) assert evaled_moment.shape == to_tuple(size) + (5,) assert np.all(evaled_moment == mu_val) @@ -298,7 +324,7 @@ def density_moment(rv, size, mu): (False, (2,)), ], ) - def test_density_dist_default_moment_multivariate(self, with_random, size): + def test_custom_dist_default_moment_multivariate(self, with_random, size): def _random(mu, rng=None, size=None): return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape) @@ -308,9 +334,10 @@ def _random(mu, rng=None, size=None): random = None mu_val = np.random.normal(loc=2, scale=1, size=5).astype(pytensor.config.floatX) - with pm.Model(): - mu = pm.Normal("mu", size=5) - a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) + with Model(): + mu = Normal("mu", size=5) + a = CustomDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) + assert isinstance(a.owner.op, CustomDistRV) if with_random: evaled_moment = moment(a).eval({mu: mu_val}) assert evaled_moment.shape == to_tuple(size) + (5,) @@ -324,7 +351,7 @@ def _random(mu, rng=None, size=None): def test_dist(self): mu = 1 - x = pm.DensityDist.dist( + x = pm.CustomDist.dist( mu, class_name="test", logp=lambda value, mu: pm.logp(pm.Normal.dist(mu), value), @@ -339,6 +366,132 @@ def test_dist(self): assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value)) +class TestCustomSymbolicDist: + def test_basic(self): + def custom_random(mu, sigma, size): + return at.exp(pm.Normal.dist(mu, sigma, size=size)) + + with Model() as m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + with pytest.warns(UserWarning, match="experimental"): + lognormal = CustomDist( + "lognormal", + mu, + sigma, + random=custom_random, + size=(10,), + transform=log, + initval=np.ones(10), + ) + + assert isinstance(lognormal.owner.op, CustomSymbolicDistRV) + + # Fix mu and sigma, so that all source of randomness comes from the symbolic RV + draws = pm.draw(lognormal, draws=3, givens={mu: 0.0, sigma: 1.0}) + assert draws.shape == (3, 10) + assert np.unique(draws).size == 30 + + with Model() as ref_m: + mu = Normal("mu") + sigma = HalfNormal("sigma") + LogNormal("lognormal", mu, sigma, size=(10,)) + + ip = m.initial_point() + np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip)) + + def test_random_multiple_rngs(self): + def custom_random(p, 
sigma, size): + idx = pm.Bernoulli.dist(p=p) + comps = pm.Normal.dist([-sigma, sigma], 1e-1, size=(*size, 2)).T + return comps[idx] + + with pytest.warns(UserWarning, match="experimental"): + customdist = CustomDist.dist( + 0.5, + 10.0, + class_name="customdist", + random=custom_random, + size=(10,), + ) + + assert isinstance(customdist.owner.op, CustomSymbolicDistRV) + + node = customdist.owner + assert len(node.inputs) == 5 # Size, 2 inputs and 2 RNGs + assert len(node.outputs) == 3 # RV and 2 updated RNGs + assert len(node.op.update(node)) == 2 + + draws = pm.draw(customdist, draws=2, random_seed=123) + assert np.unique(draws).size == 20 + + def test_custom_methods(self): + def custom_random(mu, size): + if rv_size_is_none(size): + return mu + return at.full(size, mu) + + def custom_moment(rv, size, mu): + return at.full_like(rv, mu + 1) + + def custom_logp(value, mu): + return at.full_like(value, mu + 2) + + def custom_logcdf(value, mu): + return at.full_like(value, mu + 3) + + with pytest.warns(UserWarning, match="experimental"): + customdist = CustomDist.dist( + [np.e, np.e], + class_name="customdist", + random=custom_random, + moment=custom_moment, + logp=custom_logp, + logcdf=custom_logcdf, + ) + + assert isinstance(customdist.owner.op, CustomSymbolicDistRV) + + np.testing.assert_allclose(draw(customdist), [np.e, np.e]) + np.testing.assert_allclose(moment(customdist).eval(), [np.e + 1, np.e + 1]) + np.testing.assert_allclose(logp(customdist, [0, 0]).eval(), [np.e + 2, np.e + 2]) + np.testing.assert_allclose(logcdf(customdist, [0, 0]).eval(), [np.e + 3, np.e + 3]) + + def test_change_size(self): + def custom_random(mu, sigma, size): + return at.exp(pm.Normal.dist(mu, sigma, size=size)) + + with pytest.warns(UserWarning, match="experimental"): + lognormal = CustomDist.dist( + 0, + 1, + class_name="lognormal", + random=custom_random, + size=(10,), + ) + assert isinstance(lognormal.owner.op, CustomSymbolicDistRV) + assert tuple(lognormal.shape.eval()) == (10,) + + new_lognormal = change_dist_size(lognormal, new_size=(2, 5)) + assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV) + assert tuple(new_lognormal.shape.eval()) == (2, 5) + + new_lognormal = change_dist_size(lognormal, new_size=(2, 5), expand=True) + assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV) + assert tuple(new_lognormal.shape.eval()) == (2, 5, 10) + + def test_error_model_access(self): + def random(size): + return pm.Flat("Flat", size=size) + + with pm.Model() as m: + with pytest.raises( + BlockModelAccessError, + match="Model variables cannot be created in the random function", + ): + CustomDist("custom_dist", random=random) + + class TestSymbolicRandomVarible: def test_inline(self): class TestSymbolicRV(SymbolicRandomVariable): diff --git a/pymc/tests/distributions/test_logprob.py b/pymc/tests/distributions/test_logprob.py index b775107b56..0133b4383a 100644 --- a/pymc/tests/distributions/test_logprob.py +++ b/pymc/tests/distributions/test_logprob.py @@ -32,7 +32,7 @@ import pymc as pm -from pymc import DensityDist +from pymc.distributions import CustomDist from pymc.distributions.continuous import ( HalfFlat, LogNormal, @@ -309,7 +309,7 @@ def test_model_unchanged_logprob_access(): def test_unexpected_rvs(): with Model() as model: x = Normal("x") - y = DensityDist("y", logp=lambda *args: x) + y = CustomDist("y", logp=lambda *args: x) with pytest.raises(ValueError, match="^Random variables detected in the logp graph"): model.logp() @@ -339,7 +339,7 @@ def logp(value, x): with Model() as m: x = 
Normal.dist() - y = DensityDist("y", x, logp=logp) + y = CustomDist("y", x, logp=logp) with pytest.warns( UserWarning, match="Found a random variable that was neither among the observations " @@ -355,7 +355,7 @@ def logp(value, x): # The above warning should go away with ignore_logprob. with Model() as m: x = ignore_logprob(Normal.dist()) - y = DensityDist("y", x, logp=logp) + y = CustomDist("y", x, logp=logp) with warnings.catch_warnings(): warnings.simplefilter("error") assert _joint_logp( diff --git a/pymc/tests/sampling/test_forward.py b/pymc/tests/sampling/test_forward.py index 4eaa795048..bccd9b3b1b 100644 --- a/pymc/tests/sampling/test_forward.py +++ b/pymc/tests/sampling/test_forward.py @@ -1135,7 +1135,7 @@ def test_density_dist(self): with pm.Model(): mu = pm.Normal("mu", 0, 1) sigma = pm.HalfNormal("sigma", 1e-6) - a = pm.DensityDist( + a = pm.CustomDist( "a", mu, sigma, diff --git a/pymc/tests/sampling/test_parallel.py b/pymc/tests/sampling/test_parallel.py index 76823f8ed9..0663b6d830 100644 --- a/pymc/tests/sampling/test_parallel.py +++ b/pymc/tests/sampling/test_parallel.py @@ -206,7 +206,7 @@ def test_spawn_densitydist_function(): def func(x): return -2 * (x**2).sum() - obs = pm.DensityDist("density_dist", logp=func, observed=np.random.randn(100)) + obs = pm.CustomDist("density_dist", logp=func, observed=np.random.randn(100)) with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") @@ -222,7 +222,7 @@ def logp(x, mu): out = pm.logp(normal_dist, x) return out - obs = pm.DensityDist("density_dist", mu, logp=logp, observed=np.random.randn(N), size=N) + obs = pm.CustomDist("density_dist", mu, logp=logp, observed=np.random.randn(N), size=N) with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index b614ca542e..b5276f9e51 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -401,7 +401,7 @@ def test_multiple_observed_rv(): y2_data = np.random.randn(100) with pm.Model() as model: mu = pm.Normal("mu") - x = pm.DensityDist( # pylint: disable=unused-variable + x = pm.CustomDist( # pylint: disable=unused-variable "x", mu, logp=lambda value, mu: pm.Normal.logp(value, mu, 1.0), observed=0.1 ) assert not model["x"] == model["mu"]
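Since ``DensityDist = CustomDist`` is kept as an alias, the rename is backwards compatible: existing models keep working while new code can adopt the new name. A minimal sketch assuming this branch:

.. code-block:: python

    import numpy as np
    import pymc as pm

    def logp(value, mu):
        return -((value - mu) ** 2)

    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        # Old and new spellings construct the same kind of variable
        pm.DensityDist("old_name", mu, logp=logp, observed=np.random.randn(100))
        pm.CustomDist("new_name", mu, logp=logp, observed=np.random.randn(100))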