From 4e58c46e0458157ce80ed460dc13f5fbb65ee2db Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 4 May 2025 09:41:37 -0400 Subject: [PATCH 01/15] port modules --- pymc_extras/deserialize.py | 231 ++++++ pymc_extras/prior.py | 1348 ++++++++++++++++++++++++++++++++++++ 2 files changed, 1579 insertions(+) create mode 100644 pymc_extras/deserialize.py create mode 100644 pymc_extras/prior.py diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py new file mode 100644 index 000000000..e73cc3049 --- /dev/null +++ b/pymc_extras/deserialize.py @@ -0,0 +1,231 @@ +"""Deserialize into a PyMC-Marketing object. + +This is a two step process: + +1. Determine if the data is of the correct type. +2. Deserialize the data into a python object. + +This is used to deserialize JSON data into PyMC-Marketing objects +throughout the package. + +Examples +-------- +Make use of the already registered PyMC-Marketing deserializers: + +.. code-block:: python + + from pymc_extras.deserialize import deserialize + + prior_class_data = { + "dist": "Normal", + "kwargs": {"mu": 0, "sigma": 1} + } + prior = deserialize(prior_class_data) + # Prior("Normal", mu=0, sigma=1) + +Register custom class deserialization: + +.. code-block:: python + + from pymc_extras.deserialize import register_deserialization + + class MyClass: + def __init__(self, value: int): + self.value = value + + def to_dict(self) -> dict: + # Example of what the to_dict method might look like. + return {"value": self.value} + + register_deserialization( + is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int), + deserialize=lambda data: MyClass(value=data["value"]), + ) + +Deserialize data into that custom class: + +.. code-block:: python + + from pymc_extras.deserialize import deserialize + + data = {"value": 42} + obj = deserialize(data) + assert isinstance(obj, MyClass) + + +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +IsType = Callable[[Any], bool] +Deserialize = Callable[[Any], Any] + + +@dataclass +class Deserializer: + """Object to store information required for deserialization. + + All deserializers should be stored via the :func:`register_deserialization` function + instead of creating this object directly. + + Attributes + ---------- + is_type : IsType + Function to determine if the data is of the correct type. + deserialize : Deserialize + Function to deserialize the data. + + Examples + -------- + .. code-block:: python + + from typing import Any + + class MyClass: + def __init__(self, value: int): + self.value = value + + from pymc_extras.deserialize import Deserializer + + def is_type(data: Any) -> bool: + return data.keys() == {"value"} and isinstance(data["value"], int) + + def deserialize(data: dict) -> MyClass: + return MyClass(value=data["value"]) + + deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize) + + """ + + is_type: IsType + deserialize: Deserialize + + +DESERIALIZERS: list[Deserializer] = [] + + +class DeserializableError(Exception): + """Error raised when data cannot be deserialized.""" + + def __init__(self, data: Any): + self.data = data + super().__init__( + f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping." + ) + + +def deserialize(data: Any) -> Any: + """Deserialize a dictionary into a Python object. + + Use the :func:`register_deserialization` function to add custom deserializations. 
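    If none of the registered deserializers recognise the data, a
    :class:`DeserializableError` is raised. A minimal sketch of falling back to
    a default value in that case (the payload here is made up for illustration):

    .. code-block:: python

        from pymc_extras.deserialize import DeserializableError, deserialize

        try:
            obj = deserialize({"unrecognised": "payload"})
        except DeserializableError:
            obj = None  # fall back to a default instead of failing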
+ + Deserialization is a two step process due to the dynamic nature of the data: + + 1. Determine if the data is of the correct type. + 2. Deserialize the data into a Python object. + + Each registered deserialization is checked in order until one is found that can + deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised. + + A :class:`DeserializableError` is raised when the data fails to be deserialized + by any of the registered deserializers. + + Parameters + ---------- + data : Any + The data to deserialize. + + Returns + ------- + Any + The deserialized object. + + Raises + ------ + DeserializableError + Raised when the data doesn't match any registered deserializations + or fails to be deserialized. + + Examples + -------- + Deserialize a :class:`pymc_extras.prior.Prior` object: + + .. code-block:: python + + from pymc_extras.deserialize import deserialize + + data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} + prior = deserialize(data) + # Prior("Normal", mu=0, sigma=1) + + """ + for mapping in DESERIALIZERS: + try: + is_type = mapping.is_type(data) + except Exception: + is_type = False + + if not is_type: + continue + + try: + return mapping.deserialize(data) + except Exception as e: + raise DeserializableError(data) from e + else: + raise DeserializableError(data) + + +def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None: + """Register an arbitrary deserialization. + + Use the :func:`deserialize` function to then deserialize data using all registered + deserialize functions. + + Classes from PyMC-Marketing have their deserialization mappings registered + automatically. However, custom classes will need to be registered manually + using this function before they can be deserialized. + + Parameters + ---------- + is_type : Callable[[Any], bool] + Function to determine if the data is of the correct type. + deserialize : Callable[[dict], Any] + Function to deserialize the data of that type. + + Examples + -------- + Register a custom class deserialization: + + .. code-block:: python + + from pymc_extras.deserialize import register_deserialization + + class MyClass: + def __init__(self, value: int): + self.value = value + + def to_dict(self) -> dict: + # Example of what the to_dict method might look like. + return {"value": self.value} + + register_deserialization( + is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int), + deserialize=lambda data: MyClass(value=data["value"]), + ) + + Use that custom class deserialization: + + .. code-block:: python + + from pymc_extras.deserialize import deserialize + + data = {"value": 42} + obj = deserialize(data) + assert isinstance(obj, MyClass) + + """ + mapping = Deserializer(is_type=is_type, deserialize=deserialize) + DESERIALIZERS.append(mapping) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py new file mode 100644 index 000000000..715df6874 --- /dev/null +++ b/pymc_extras/prior.py @@ -0,0 +1,1348 @@ +"""Class that represents a prior distribution. + +The `Prior` class is a wrapper around PyMC distributions that allows the user +to create outside of the PyMC model. + +Examples +-------- +Create a normal prior. + +.. code-block:: python + + from pymc_extras.prior import Prior + + normal = Prior("Normal") + +Create a hierarchical normal prior by using distributions for the parameters +and specifying the dims. + +.. 
code-block:: python + + hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + +Create a non-centered hierarchical normal prior with the `centered` parameter. + +.. code-block:: python + + non_centered_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + # Only change needed to make it non-centered + centered=False, + ) + +Create a hierarchical beta prior by using Beta distribution, distributions for +the parameters, and specifying the dims. + +.. code-block:: python + + hierarchical_beta = Prior( + "Beta", + alpha=Prior("HalfNormal"), + beta=Prior("HalfNormal"), + dims="channel", + ) + +Create a transformed hierarchical normal prior by using the `transform` +parameter. Here the "sigmoid" transformation comes from `pm.math`. + +.. code-block:: python + + transformed_hierarchical_normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + transform="sigmoid", + dims="channel", + ) + +Create a prior with a custom transform function by registering it with +`register_tensor_transform`. + +.. code-block:: python + + from pymc_extras.prior import register_tensor_transform + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + +""" + +from __future__ import annotations + +import copy +from collections.abc import Callable +from inspect import signature +from typing import Any, Protocol, runtime_checkable + +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import xarray as xr +from pydantic import InstanceOf, validate_call +from pydantic.dataclasses import dataclass +from pymc.distributions.shape_utils import Dims + +from pymc_extras.deserialize import deserialize, register_deserialization + + +class UnsupportedShapeError(Exception): + """Error for when the shapes from variables are not compatible.""" + + +class UnsupportedDistributionError(Exception): + """Error for when an unsupported distribution is used.""" + + +class UnsupportedParameterizationError(Exception): + """The follow parameterization is not supported.""" + + +class MuAlreadyExistsError(Exception): + """Error for when 'mu' is present in Prior.""" + + def __init__(self, distribution: Prior) -> None: + self.distribution = distribution + self.message = f"The mu parameter is already defined in {distribution}" + super().__init__(self.message) + + +class UnknownTransformError(Exception): + """Error for when an unknown transform is used.""" + + +def _remove_leading_xs(args: list[str | int]) -> list[str | int]: + """Remove leading 'x' from the args.""" + while args and args[0] == "x": + args.pop(0) + + return args + + +def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable: + """Take a tensor of dims `dims` and align it to `desired_dims`. + + Doesn't check for validity of the dims + + Examples + -------- + 1D to 2D with new dim + + .. 
code-block:: python + + x = np.array([1, 2, 3]) + dims = "channel" + + desired_dims = ("channel", "group") + + handle_dims(x, dims, desired_dims) + + """ + x = pt.as_tensor_variable(x) + + if np.ndim(x) == 0: + return x + + dims = dims if isinstance(dims, tuple) else (dims,) + desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,) + + if difference := set(dims).difference(desired_dims): + raise UnsupportedShapeError( + f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. " + f"{difference} is missing from the desired dims." + ) + + aligned_dims = np.array(dims)[:, None] == np.array(desired_dims) + + missing_dims = aligned_dims.sum(axis=0) == 0 + new_idx = aligned_dims.argmax(axis=0) + + args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)] + args = _remove_leading_xs(args) + return x.dimshuffle(*args) + + +DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike] + + +def create_dim_handler(desired_dims: Dims) -> DimHandler: + """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function.""" + + def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable: + return handle_dims(x, dims, desired_dims) + + return func + + +def _dims_to_str(obj: tuple[str, ...]) -> str: + if len(obj) == 1: + return f'"{obj[0]}"' + + return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")" + + +def _get_pymc_distribution(name: str) -> type[pm.Distribution]: + if not hasattr(pm, name): + raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}") + + return getattr(pm, name) + + +Transform = Callable[[pt.TensorLike], pt.TensorLike] + +CUSTOM_TRANSFORMS: dict[str, Transform] = {} + + +def register_tensor_transform(name: str, transform: Transform) -> None: + """Register a tensor transform function to be used in the `Prior` class. + + Parameters + ---------- + name : str + The name of the transform. + func : Callable[[pt.TensorLike], pt.TensorLike] + The function to apply to the tensor. + + Examples + -------- + Register a custom transform function. + + .. code-block:: python + + from pymc_extras.prior import ( + Prior, + register_tensor_transform, + ) + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + + """ + CUSTOM_TRANSFORMS[name] = transform + + +def _get_transform(name: str): + if name in CUSTOM_TRANSFORMS: + return CUSTOM_TRANSFORMS[name] + + for module in (pt, pm.math): + if hasattr(module, name): + break + else: + module = None + + if not module: + msg = ( + f"Neither pytensor.tensor nor pymc.math have the function {name!r}. " + "If this is a custom function, register it with the " + "`pymc_extras.prior.register_tensor_transform` function before " + "previous function call." + ) + + raise UnknownTransformError(msg) + + return getattr(module, name) + + +def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]: + return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"} + + +@runtime_checkable +class VariableFactory(Protocol): + """Protocol for something that works like a Prior class.""" + + dims: tuple[str, ...] 
+ + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a TensorVariable.""" + + +def sample_prior( + factory: VariableFactory, + coords=None, + name: str = "var", + wrap: bool = False, + **sample_prior_predictive_kwargs, +) -> xr.Dataset: + """Sample the prior for an arbitrary VariableFactory. + + Parameters + ---------- + factory : VariableFactory + The factory to sample from. + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + wrap : bool, optional + Whether to wrap the variable in a `pm.Deterministic` node, by default False. + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from an arbitrary variable factory. + + .. code-block:: python + + import pymc as pm + + import pytensor.tensor as pt + + from pymc_extras.prior import sample_prior + + class CustomVariableDefinition: + def __init__(self, dims, n: int): + self.dims = dims + self.n = n + + def create_variable(self, name: str) -> "TensorVariable": + x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims) + return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0) + + cubic = CustomVariableDefinition(dims=("channel",), n=3) + coords = {"channel": ["C1", "C2", "C3"]} + # Doesn't include the return value + prior = sample_prior(cubic, coords=coords) + + prior_with = sample_prior(cubic, coords=coords, wrap=True) + + """ + coords = coords or {} + + if isinstance(factory.dims, str): + dims = (factory.dims,) + else: + dims = factory.dims + + if missing_keys := set(dims) - set(coords.keys()): + raise KeyError(f"Coords are missing the following dims: {missing_keys}") + + with pm.Model(coords=coords) as model: + if wrap: + pm.Deterministic(name, factory.create_variable(name), dims=factory.dims) + else: + factory.create_variable(name) + + return pm.sample_prior_predictive( + model=model, + **sample_prior_predictive_kwargs, + ).prior + + +class Prior: + """A class to represent a prior distribution. + + Make use of the various helper methods to understand the distributions + better. + + - `preliz` attribute to get the equivalent distribution in `preliz` + - `sample_prior` method to sample from the prior + - `graph` get a dummy model graph with the distribution + - `constrain` to shift the distribution to a different range + + Parameters + ---------- + distribution : str + The name of PyMC distribution. + dims : Dims, optional + The dimensions of the variable, by default None + centered : bool, optional + Whether the variable is centered or not, by default True. + Only allowed for Normal distribution. + transform : str, optional + The name of the transform to apply to the variable after it is + created, by default None or no transform. The transformation must + be registered with `register_tensor_transform` function or + be available in either `pytensor.tensor` or `pymc.math`. 
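
    Examples
    --------
    A non-centered hierarchical normal prior with a sigmoid transform (the
    "channel" dim name is illustrative):

    .. code-block:: python

        from pymc_extras.prior import Prior

        prior = Prior(
            "Normal",
            mu=Prior("Normal"),
            sigma=Prior("HalfNormal"),
            dims="channel",
            centered=False,
            transform="sigmoid",
        )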
+ + """ + + # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family + non_centered_distributions: dict[str, dict[str, float]] = { + "Normal": {"mu": 0, "sigma": 1}, + "StudentT": {"mu": 0, "sigma": 1}, + "ZeroSumNormal": {"sigma": 1}, + } + + pymc_distribution: type[pm.Distribution] + pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None + + @validate_call + def __init__( + self, + distribution: str, + *, + dims: Dims | None = None, + centered: bool = True, + transform: str | None = None, + **parameters, + ) -> None: + self.distribution = distribution + self.parameters = parameters + self.dims = dims + self.centered = centered + self.transform = transform + + self._checks() + + @property + def distribution(self) -> str: + """The name of the PyMC distribution.""" + return self._distribution + + @distribution.setter + def distribution(self, distribution: str) -> None: + if hasattr(self, "_distribution"): + raise AttributeError("Can't change the distribution") + + self._distribution = distribution + self.pymc_distribution = _get_pymc_distribution(distribution) + + @property + def transform(self) -> str | None: + """The name of the transform to apply to the variable after it is created.""" + return self._transform + + @transform.setter + def transform(self, transform: str | None) -> None: + self._transform = transform + self.pytensor_transform = not transform or _get_transform(transform) # type: ignore + + @property + def dims(self) -> Dims: + """The dimensions of the variable.""" + return self._dims + + @dims.setter + def dims(self, dims) -> None: + if isinstance(dims, str): + dims = (dims,) + + self._dims = dims or () + + self._param_dims_work() + self._unique_dims() + + def __getitem__(self, key: str) -> Prior | Any: + """Return the parameter of the prior.""" + return self.parameters[key] + + def _checks(self) -> None: + if not self.centered: + self._correct_non_centered_distribution() + + self._parameters_are_at_least_subset_of_pymc() + self._convert_lists_to_numpy() + self._parameters_are_correct_type() + + def _parameters_are_at_least_subset_of_pymc(self) -> None: + pymc_params = _get_pymc_parameters(self.pymc_distribution) + if not set(self.parameters.keys()).issubset(pymc_params): + msg = ( + f"Parameters {set(self.parameters.keys())} " + "are not a subset of the pymc distribution " + f"parameters {set(pymc_params)}" + ) + raise ValueError(msg) + + def _convert_lists_to_numpy(self) -> None: + def convert(x): + if not isinstance(x, list): + return x + + return np.array(x) + + self.parameters = {key: convert(value) for key, value in self.parameters.items()} + + def _parameters_are_correct_type(self) -> None: + supported_types = ( + int, + float, + np.ndarray, + Prior, + pt.TensorVariable, + VariableFactory, + ) + + incorrect_types = { + param: type(value) + for param, value in self.parameters.items() + if not isinstance(value, supported_types) + } + if incorrect_types: + msg = ( + "Parameters must be one of the following types: " + f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}" + ) + raise ValueError(msg) + + def _correct_non_centered_distribution(self) -> None: + if not self.centered and self.distribution not in self.non_centered_distributions: + raise UnsupportedParameterizationError( + f"{self.distribution!r} is not supported for non-centered parameterization. 
" + f"Choose from {list(self.non_centered_distributions.keys())}" + ) + + required_parameters = set(self.non_centered_distributions[self.distribution].keys()) + + if set(self.parameters.keys()) < required_parameters: + msg = " and ".join([f"{param!r}" for param in required_parameters]) + raise ValueError( + f"Must have at least {msg} parameter for non-centered for {self.distribution!r}" + ) + + def _unique_dims(self) -> None: + if not self.dims: + return + + if len(self.dims) != len(set(self.dims)): + raise ValueError("Dims must be unique") + + def _param_dims_work(self) -> None: + other_dims = set() + for value in self.parameters.values(): + if hasattr(value, "dims"): + other_dims.update(value.dims) + + if not other_dims.issubset(self.dims): + raise UnsupportedShapeError( + f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}" + ) + + def __str__(self) -> str: + """Return a string representation of the prior.""" + param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()]) + param_str = "" if not param_str else f", {param_str}" + + dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else "" + centered_str = f", centered={self.centered}" if not self.centered else "" + transform_str = f', transform="{self.transform}"' if self.transform else "" + return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})' + + def __repr__(self) -> str: + """Return a string representation of the prior.""" + return f"{self}" + + def _create_parameter(self, param, value, name): + if not hasattr(value, "create_variable"): + return value + + child_name = f"{name}_{param}" + return self.dim_handler(value.create_variable(child_name), value.dims) + + def _create_centered_variable(self, name: str): + parameters = { + param: self._create_parameter(param, value, name) + for param, value in self.parameters.items() + } + return self.pymc_distribution(name, **parameters, dims=self.dims) + + def _create_non_centered_variable(self, name: str) -> pt.TensorVariable: + def handle_variable(var_name: str): + parameter = self.parameters[var_name] + if not hasattr(parameter, "create_variable"): + return parameter + + return self.dim_handler( + parameter.create_variable(f"{name}_{var_name}"), + parameter.dims, + ) + + defaults = self.non_centered_distributions[self.distribution] + other_parameters = { + param: handle_variable(param) + for param in self.parameters.keys() + if param not in defaults + } + offset = self.pymc_distribution( + f"{name}_offset", + **defaults, + **other_parameters, + dims=self.dims, + ) + if "mu" in self.parameters: + mu = ( + handle_variable("mu") + if isinstance(self.parameters["mu"], Prior) + else self.parameters["mu"] + ) + else: + mu = 0 + + sigma = ( + handle_variable("sigma") + if isinstance(self.parameters["sigma"], Prior) + else self.parameters["sigma"] + ) + + return pm.Deterministic( + name, + mu + sigma * offset, + dims=self.dims, + ) + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a PyMC variable from the prior. + + Must be used in a PyMC model context. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a hierarchical normal variable in larger PyMC model. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + with pm.Model(coords=coords): + var = dist.create_variable("var") + + """ + self.dim_handler = create_dim_handler(self.dims) + + if self.transform: + var_name = f"{name}_raw" + + def transform(var): + return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims) + else: + var_name = name + + def transform(var): + return var + + create_variable = ( + self._create_centered_variable if self.centered else self._create_non_centered_variable + ) + var = create_variable(name=var_name) + return transform(var) + + @property + def preliz(self): + """Create an equivalent preliz distribution. + + Helpful to visualize a distribution when it is univariate. + + Returns + ------- + preliz.distributions.Distribution + + Examples + -------- + Create a preliz distribution from a prior. + + .. code-block:: python + + from pymc_extras.prior import Prior + + dist = Prior("Gamma", alpha=5, beta=1) + dist.preliz.plot_pdf() + + """ + import preliz as pz + + return getattr(pz, self.distribution)(**self.parameters) + + def to_dict(self) -> dict[str, Any]: + """Convert the prior to dictionary format. + + Returns + ------- + dict[str, Any] + The dictionary format of the prior. + + Examples + -------- + Convert a prior to the dictionary format. + + .. code-block:: python + + from pymc_extras.prior import Prior + + dist = Prior("Normal", mu=0, sigma=1) + + dist.to_dict() + # {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} + + Convert a hierarchical prior to the dictionary format. + + .. code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + dist.to_dict() + # { + # "dist": "Normal", + # "kwargs": { + # "mu": {"dist": "Normal"}, + # "sigma": {"dist": "HalfNormal"}, + # }, + # "dims": "channel", + # } + + """ + data: dict[str, Any] = { + "dist": self.distribution, + } + if self.parameters: + + def handle_value(value): + if isinstance(value, Prior): + return value.to_dict() + + if isinstance(value, pt.TensorVariable): + value = value.eval() + + if isinstance(value, np.ndarray): + return value.tolist() + + if hasattr(value, "to_dict"): + return value.to_dict() + + return value + + data["kwargs"] = { + param: handle_value(value) for param, value in self.parameters.items() + } + if not self.centered: + data["centered"] = False + + if self.dims: + data["dims"] = self.dims + + if self.transform: + data["transform"] = self.transform + + return data + + @classmethod + def from_dict(cls, data) -> Prior: + """Create a Prior from the dictionary format. + + Parameters + ---------- + data : dict[str, Any] + The dictionary format of the prior. + + Returns + ------- + Prior + The prior distribution. + + Examples + -------- + Convert prior in the dictionary format to a Prior instance. + + .. code-block:: python + + from pymc_extras.prior import Prior + + data = { + "dist": "Normal", + "kwargs": {"mu": 0, "sigma": 1}, + } + + dist = Prior.from_dict(data) + dist + # Prior("Normal", mu=0, sigma=1) + + """ + if not isinstance(data, dict): + msg = ( + "Must be a dictionary representation of a prior distribution. 
" + f"Not of type: {type(data)}" + ) + raise ValueError(msg) + + dist = data["dist"] + kwargs = data.get("kwargs", {}) + + def handle_value(value): + if isinstance(value, dict): + return deserialize(value) + + if isinstance(value, list): + return np.array(value) + + return value + + kwargs = {param: handle_value(value) for param, value in kwargs.items()} + centered = data.get("centered", True) + dims = data.get("dims") + if isinstance(dims, list): + dims = tuple(dims) + transform = data.get("transform") + + return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs) + + def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior: + """Create a new prior with a given mass constrained within the given bounds. + + Wrapper around `preliz.maxent`. + + Parameters + ---------- + lower : float + The lower bound. + upper : float + The upper bound. + mass: float = 0.95 + The mass of the distribution to keep within the bounds. + kwargs : dict + Additional arguments to pass to `pz.maxent`. + + Returns + ------- + Prior + The maximum entropy prior with a mass constrained to the given bounds. + + Examples + -------- + Create a Beta distribution that is constrained to have 95% of the mass + between 0.5 and 0.8. + + .. code-block:: python + + dist = Prior( + "Beta", + ).constrain(lower=0.5, upper=0.8) + + Create a Beta distribution with mean 0.6, that is constrained to + have 95% of the mass between 0.5 and 0.8. + + .. code-block:: python + + dist = Prior( + "Beta", + mu=0.6, + ).constrain(lower=0.5, upper=0.8) + + """ + from preliz import maxent + + if self.transform: + raise ValueError("Can't constrain a transformed variable") + + if kwargs is None: + kwargs = {} + kwargs.setdefault("plot", False) + + if kwargs["plot"]: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict + else: + new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict + + return Prior( + self.distribution, + dims=self.dims, + transform=self.transform, + centered=self.centered, + **new_parameters, + ) + + def __eq__(self, other) -> bool: + """Check if two priors are equal.""" + if not isinstance(other, Prior): + return False + + try: + np.testing.assert_equal(self.parameters, other.parameters) + except AssertionError: + return False + + return ( + self.distribution == other.distribution + and self.dims == other.dims + and self.centered == other.centered + and self.transform == other.transform + ) + + def sample_prior( + self, + coords=None, + name: str = "var", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a hierarchical normal distribution. + + .. 
code-block:: python + + dist = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def __deepcopy__(self, memo) -> Prior: + """Return a deep copy of the prior.""" + if id(self) in memo: + return memo[id(self)] + + copy_obj = Prior( + self.distribution, + dims=copy.copy(self.dims), + centered=self.centered, + transform=self.transform, + **copy.deepcopy(self.parameters), + ) + memo[id(self)] = copy_obj + return copy_obj + + def deepcopy(self) -> Prior: + """Return a deep copy of the prior.""" + return copy.deepcopy(self) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create the graph for a 2D transformed hierarchical distribution. + + .. code-block:: python + + from pymc_extras.prior import Prior + + mu = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + sigma = Prior("HalfNormal", dims="channel") + dist = Prior( + "Normal", + mu=mu, + sigma=sigma, + dims=("channel", "geo"), + centered=False, + transform="sigmoid", + ) + + dist.to_graph() + + .. image:: /_static/example-graph.png + :alt: Example graph + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create a likelihood variable from the prior. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + + dist = Prior("Normal", sigma=Prior("HalfNormal")) + + with pm.Model(): + # Create the likelihood variable + mu = pm.Normal("mu", mu=0, sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution!r} is not supported." + ) + + if "mu" in self.parameters: + raise MuAlreadyExistsError(self) + + distribution = self.deepcopy() + distribution.parameters["mu"] = mu + distribution.parameters["observed"] = observed + return distribution.create_variable(name) + + +class VariableNotFound(Exception): + """Variable is not found.""" + + +def _remove_random_variable(var: pt.TensorVariable) -> None: + if var.name is None: + raise ValueError("This isn't removable") + + name: str = var.name + + model = pm.modelcontext(None) + for idx, free_rv in enumerate(model.free_RVs): + if var == free_rv: + index_to_remove = idx + break + else: + raise VariableNotFound(f"Variable {var.name!r} not found") + + var.name = None + model.free_RVs.pop(index_to_remove) + model.named_vars.pop(name) + + +@dataclass +class Censored: + """Create censored random variable. + + Examples + -------- + Create a censored Normal distribution: + + .. 
code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + Create hierarchical censored Normal distribution: + + .. code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + samples = censored_normal.sample_prior(coords=coords) + + """ + + distribution: InstanceOf[Prior] + lower: float | InstanceOf[pt.TensorVariable] = -np.inf + upper: float | InstanceOf[pt.TensorVariable] = np.inf + + def __post_init__(self) -> None: + """Check validity at initialization.""" + if not self.distribution.centered: + raise ValueError( + "Censored distribution must be centered so that .dist() API can be used on distribution." + ) + + if self.distribution.transform is not None: + raise ValueError( + "Censored distribution can't have a transform so that .dist() API can be used on distribution." + ) + + @property + def dims(self) -> tuple[str, ...]: + """The dims from the distribution to censor.""" + return self.distribution.dims + + @dims.setter + def dims(self, dims) -> None: + self.distribution.dims = dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create censored random variable.""" + dist = self.distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + def to_dict(self) -> dict[str, Any]: + """Convert the censored distribution to a dictionary.""" + + def handle_value(value): + if isinstance(value, pt.TensorVariable): + return value.eval().tolist() + + return value + + return { + "class": "Censored", + "data": { + "dist": self.distribution.to_dict(), + "lower": handle_value(self.lower), + "upper": handle_value(self.upper), + }, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Censored: + """Create a censored distribution from a dictionary.""" + data = data["data"] + return cls( # type: ignore + distribution=Prior.from_dict(data["dist"]), + lower=data["lower"], + upper=data["upper"], + ) + + def sample_prior( + self, + coords=None, + name: str = "variable", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a censored Gamma distribution. + + .. code-block:: python + + gamma = Prior("Gamma", mu=1, sigma=1, dims="channel") + dist = Censored(gamma, lower=0.5) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + return sample_prior( + factory=self, + coords=coords, + name=name, + **sample_prior_predictive_kwargs, + ) + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create graph for a censored Normal distribution + + .. 
code-block:: python + + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + censored_normal.to_graph() + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create observed censored variable. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a censored likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + from pymc_extras.prior import Prior, Censored + + normal = Prior("Normal", sigma=Prior("HalfNormal")) + dist = Censored(normal, lower=0) + + observed = 1 + + with pm.Model(): + # Create the likelihood variable + mu = pm.HalfNormal("mu", sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution.distribution!r} is not supported." + ) + + if "mu" in self.distribution.parameters: + raise MuAlreadyExistsError(self.distribution) + + distribution = self.distribution.deepcopy() + distribution.parameters["mu"] = mu + + dist = distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + observed=observed, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + +class Scaled: + """Scaled distribution for numerical stability.""" + + def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None: + self.dist = dist + self.factor = factor + + @property + def dims(self) -> Dims: + """The dimensions of the scaled distribution.""" + return self.dist.dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create a scaled variable. + + Parameters + ---------- + name : str + The name of the variable. + + Returns + ------- + pt.TensorVariable + The scaled variable. 
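
        Examples
        --------
        Scale a HalfNormal prior by a constant factor inside a model (the
        variable and coordinate names are illustrative):

        .. code-block:: python

            import pymc as pm

            from pymc_extras.prior import Prior, Scaled

            scaled = Scaled(Prior("HalfNormal", dims="channel"), factor=10)

            coords = {"channel": ["C1", "C2", "C3"]}
            with pm.Model(coords=coords):
                var = scaled.create_variable("var")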
+ """ + var = self.dist.create_variable(f"{name}_unscaled") + return pm.Deterministic(name, var * self.factor, dims=self.dims) From 9dfc00ac6f28bdad00af4e0bac9e1c5486ff8c51 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 4 May 2025 09:42:16 -0400 Subject: [PATCH 02/15] port tests --- tests/test_deserialize.py | 59 ++ tests/test_prior.py | 1144 +++++++++++++++++++++++++++++++++++++ 2 files changed, 1203 insertions(+) create mode 100644 tests/test_deserialize.py create mode 100644 tests/test_prior.py diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py new file mode 100644 index 000000000..2fcc28a13 --- /dev/null +++ b/tests/test_deserialize.py @@ -0,0 +1,59 @@ +import pytest + +from pymc_extras.deserialize import ( + DESERIALIZERS, + DeserializableError, + deserialize, + register_deserialization, +) + + +@pytest.mark.parametrize( + "unknown_data", + [ + {"unknown": 1}, + {"dist": "Normal", "kwargs": {"something": "else"}}, + 1, + ], + ids=["unknown_structure", "prior_like", "non_dict"], +) +def test_unknown_type_raises(unknown_data) -> None: + match = "Couldn't deserialize" + with pytest.raises(DeserializableError, match=match): + deserialize(unknown_data) + + +class ArbitraryObject: + def __init__(self, code: str): + self.code = code + self.value = 1 + + +@pytest.fixture +def register_arbitrary_object(): + register_deserialization( + is_type=lambda data: data.keys() == {"code"}, + deserialize=lambda data: ArbitraryObject(code=data["code"]), + ) + + yield + + DESERIALIZERS.pop() + + +def test_registration(register_arbitrary_object) -> None: + instance = deserialize({"code": "test"}) + + assert isinstance(instance, ArbitraryObject) + assert instance.code == "test" + + +def test_registeration_mixup() -> None: + data_that_looks_like_prior = { + "dist": "Normal", + "kwargs": {"something": "else"}, + } + + match = "Couldn't deserialize" + with pytest.raises(DeserializableError, match=match): + deserialize(data_that_looks_like_prior) diff --git a/tests/test_prior.py b/tests/test_prior.py new file mode 100644 index 000000000..d0a3490ea --- /dev/null +++ b/tests/test_prior.py @@ -0,0 +1,1144 @@ +from copy import deepcopy +from typing import NamedTuple + +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest +import xarray as xr + +from graphviz.graphs import Digraph +from pydantic import ValidationError +from pymc.model_graph import fast_eval + +from pymc_extras.deserialize import ( + DESERIALIZERS, + deserialize, + register_deserialization, +) +from pymc_extras.prior import ( + Censored, + MuAlreadyExistsError, + Prior, + Scaled, + UnknownTransformError, + UnsupportedDistributionError, + UnsupportedParameterizationError, + UnsupportedShapeError, + VariableFactory, + handle_dims, + register_tensor_transform, + sample_prior, +) + +pz = pytest.importorskip("preliz") + + +@pytest.mark.parametrize( + "x, dims, desired_dims, expected_fn", + [ + (np.arange(3), "channel", "channel", lambda x: x), + (np.arange(3), "channel", ("geo", "channel"), lambda x: x), + (np.arange(3), "channel", ("channel", "geo"), lambda x: x[:, None]), + (np.arange(3), "channel", ("x", "y", "channel", "geo"), lambda x: x[:, None]), + ( + np.arange(3 * 2).reshape(3, 2), + ("channel", "geo"), + ("geo", "x", "y", "channel"), + lambda x: x.T[:, None, None, :], + ), + ( + np.arange(4 * 2 * 3).reshape(4, 2, 3), + ("channel", "geo", "store"), + ("geo", "x", "store", "channel"), + lambda x: x.swapaxes(0, 2).swapaxes(0, 1)[:, None, :, :], + ), + ], + ids=[ + "same_dims", + "different_dims", + 
"dim_padding", + "just_enough_dims", + "transpose_and_padding", + "swaps_and_padding", + ], +) +def test_handle_dims(x, dims, desired_dims, expected_fn) -> None: + result = handle_dims(x, dims, desired_dims) + if isinstance(result, pt.TensorVariable): + result = fast_eval(result) + + np.testing.assert_array_equal(result, expected_fn(x)) + + +@pytest.mark.parametrize( + "x, dims, desired_dims", + [ + (np.ones(3), "channel", "something_else"), + (np.ones((3, 2)), ("a", "b"), ("a", "B")), + ], + ids=["no_incommon", "some_incommon"], +) +def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None: + match = " are not a subset of the desired dims " + with pytest.raises(UnsupportedShapeError, match=match): + handle_dims(x, dims, desired_dims) + + +def test_missing_transform() -> None: + match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'" + with pytest.raises(UnknownTransformError, match=match): + Prior("Normal", transform="foo_bar") + + +def test_get_item() -> None: + var = Prior("Normal", mu=0, sigma=1) + + assert var["mu"] == 0 + assert var["sigma"] == 1 + + +def test_noncentered_needs_params() -> None: + with pytest.raises(ValueError): + Prior( + "Normal", + centered=False, + ) + + +def test_different_than_pymc_params() -> None: + with pytest.raises(ValueError): + Prior("Normal", mu=0, b=1) + + +def test_non_unique_dims() -> None: + with pytest.raises(ValueError): + Prior("Normal", mu=0, sigma=1, dims=("channel", "channel")) + + +def test_doesnt_check_validity_parameterization() -> None: + try: + Prior("Normal", mu=0, sigma=1, tau=1) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_doesnt_check_validity_values() -> None: + try: + Prior("Normal", mu=0, sigma=-1) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_preliz() -> None: + var = Prior("Normal", mu=0, sigma=1) + dist = var.preliz + assert isinstance(dist, pz.distributions.Distribution) + + +@pytest.mark.parametrize( + "var, expected", + [ + (Prior("Normal", mu=0, sigma=1), 'Prior("Normal", mu=0, sigma=1)'), + ( + Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")), + 'Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"))', + ), + (Prior("Normal", dims="channel"), 'Prior("Normal", dims="channel")'), + ( + Prior("Normal", mu=0, sigma=1, transform="sigmoid"), + 'Prior("Normal", mu=0, sigma=1, transform="sigmoid")', + ), + ], +) +def test_str(var, expected) -> None: + assert str(var) == expected + + +@pytest.mark.parametrize( + "var", + [ + Prior("Normal", mu=0, sigma=1), + Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel"), + Prior("Normal", dims=("geo", "channel")), + ], +) +def test_repr(var) -> None: + assert eval(repr(var)) == var + + +def test_invalid_distribution() -> None: + with pytest.raises(UnsupportedDistributionError): + Prior("Invalid") + + +def test_broadcast_doesnt_work(): + with pytest.raises(UnsupportedShapeError): + Prior( + "Normal", + mu=0, + sigma=Prior("HalfNormal", sigma=1, dims="x"), + dims="y", + ) + + +def test_dim_workaround_flaw() -> None: + distribution = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="y", + ) + + try: + distribution["mu"].dims = "x" + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + with pytest.raises(UnsupportedShapeError): + distribution._param_dims_work() + + +def test_noncentered_error() -> None: + with pytest.raises(UnsupportedParameterizationError): + Prior( + "Gamma", + mu=0, + sigma=1, 
+ dims="x", + centered=False, + ) + + +def test_create_variable_multiple_times() -> None: + mu = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ) + + coords = { + "channel": ["a", "b", "c"], + } + with pm.Model(coords=coords) as model: + mu.create_variable("mu") + mu.create_variable("mu_2") + + suffixes = [ + "", + "_offset", + "_mu", + "_sigma", + ] + dims = [(3,), (3,), (), ()] + + for prefix in ["mu", "mu_2"]: + for suffix, dim in zip(suffixes, dims, strict=False): + assert fast_eval(model[f"{prefix}{suffix}"]).shape == dim + + +@pytest.fixture +def large_var() -> Prior: + mu = Prior( + "Normal", + mu=Prior("Normal", mu=1), + sigma=Prior("HalfNormal"), + dims="channel", + centered=False, + ) + sigma = Prior("HalfNormal", sigma=Prior("HalfNormal"), dims="geo") + + return Prior("Normal", mu=mu, sigma=sigma, dims=("geo", "channel")) + + +def test_create_variable(large_var) -> None: + coords = { + "channel": ["a", "b", "c"], + "geo": ["x", "y"], + } + with pm.Model(coords=coords) as model: + large_var.create_variable("var") + + var_names = [ + "var", + "var_mu", + "var_sigma", + "var_mu_offset", + "var_mu_mu", + "var_mu_sigma", + "var_sigma_sigma", + ] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [ + (2, 3), + (3,), + (2,), + (3,), + (), + (), + (), + ] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_transform() -> None: + var = Prior("Normal", mu=0, sigma=1, transform="sigmoid") + + with pm.Model() as model: + var.create_variable("var") + + var_names = [ + "var", + "var_raw", + ] + dims = [ + (), + (), + ] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_to_dict(large_var) -> None: + data = large_var.to_dict() + + assert data == { + "dist": "Normal", + "kwargs": { + "mu": { + "dist": "Normal", + "kwargs": { + "mu": { + "dist": "Normal", + "kwargs": { + "mu": 1, + }, + }, + "sigma": { + "dist": "HalfNormal", + }, + }, + "centered": False, + "dims": ("channel",), + }, + "sigma": { + "dist": "HalfNormal", + "kwargs": { + "sigma": { + "dist": "HalfNormal", + }, + }, + "dims": ("geo",), + }, + }, + "dims": ("geo", "channel"), + } + + +def test_to_dict_numpy() -> None: + var = Prior("Normal", mu=np.array([0, 10, 20]), dims="channel") + assert var.to_dict() == { + "dist": "Normal", + "kwargs": { + "mu": [0, 10, 20], + }, + "dims": ("channel",), + } + + +def test_dict_round_trip(large_var) -> None: + assert Prior.from_dict(large_var.to_dict()) == large_var + + +def test_constrain_with_transform_error() -> None: + var = Prior("Normal", transform="sigmoid") + + with pytest.raises(ValueError): + var.constrain(lower=0, upper=1) + + +def test_constrain(mocker) -> None: + var = Prior("Normal") + + mocker.patch( + "preliz.maxent", + return_value=mocker.Mock(params_dict={"mu": 5, "sigma": 2}), + ) + + new_var = var.constrain(lower=0, upper=1) + assert new_var == Prior("Normal", mu=5, sigma=2) + + +def test_dims_change() -> None: + var = Prior("Normal", mu=0, sigma=1) + var.dims = "channel" + + assert var.dims == ("channel",) + + +def test_dims_change_error() -> None: + mu = Prior("Normal", dims="channel") + var = Prior("Normal", mu=mu, dims="channel") + + with pytest.raises(UnsupportedShapeError): + var.dims = "geo" + + +def test_deepcopy() -> None: + priors = { + "alpha": Prior("Beta", alpha=1, beta=1), + "gamma": Prior("Normal", mu=0, sigma=1), + } + + new_priors 
= deepcopy(priors) + priors["alpha"].dims = "channel" + + assert new_priors["alpha"].dims == () + + +@pytest.fixture +def mmm_default_model_config(): + return { + "intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, + "likelihood": { + "dist": "Normal", + "kwargs": { + "sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}, + }, + }, + "gamma_control": { + "dist": "Normal", + "kwargs": {"mu": 0, "sigma": 2}, + "dims": "control", + }, + "gamma_fourier": { + "dist": "Laplace", + "kwargs": {"mu": 0, "b": 1}, + "dims": "fourier_mode", + }, + } + + +def test_backwards_compat(mmm_default_model_config) -> None: + result = {param: Prior.from_dict(value) for param, value in mmm_default_model_config.items()} + assert result == { + "intercept": Prior("Normal", mu=0, sigma=2), + "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)), + "gamma_control": Prior("Normal", mu=0, sigma=2, dims="control"), + "gamma_fourier": Prior("Laplace", mu=0, b=1, dims="fourier_mode"), + } + + +def test_sample_prior() -> None: + var = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + transform="sigmoid", + ) + + coords = {"channel": ["A", "B", "C"]} + prior = var.sample_prior(coords=coords, samples=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + +def test_sample_prior_missing_coords() -> None: + dist = Prior("Normal", dims="channel") + + with pytest.raises(KeyError, match="Coords"): + dist.sample_prior() + + +def test_to_graph() -> None: + hierarchical_distribution = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + + G = hierarchical_distribution.to_graph() + assert isinstance(G, Digraph) + + +def test_from_dict_list() -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": [0, 1, 2], + "sigma": 1, + }, + "dims": "channel", + } + + var = Prior.from_dict(data) + assert var.dims == ("channel",) + assert isinstance(var["mu"], np.ndarray) + np.testing.assert_array_equal(var["mu"], [0, 1, 2]) + + +def test_from_dict_list_dims() -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": 0, + "sigma": 1, + }, + "dims": ["channel", "geo"], + } + + var = Prior.from_dict(data) + assert var.dims == ("channel", "geo") + + +def test_to_dict_transform() -> None: + dist = Prior("Normal", transform="sigmoid") + + data = dist.to_dict() + assert data == { + "dist": "Normal", + "transform": "sigmoid", + } + + +def test_equality_non_prior() -> None: + dist = Prior("Normal") + + assert dist != 1 + + +def test_deepcopy_memo() -> None: + memo = {} + dist = Prior("Normal") + memo[id(dist)] = dist + deepcopy(dist, memo) + assert len(memo) == 1 + deepcopy(dist, memo) + + assert len(memo) == 1 + + +def test_create_likelihood_variable() -> None: + distribution = Prior("Normal", sigma=Prior("HalfNormal")) + + with pm.Model() as model: + mu = pm.Normal("mu") + + data = distribution.create_likelihood_variable("data", mu=mu, observed=10) + + assert model.observed_RVs == [data] + assert "data_sigma" in model + + +def test_create_likelihood_variable_already_has_mu() -> None: + distribution = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) + + with pm.Model(): + mu = pm.Normal("mu") + + with pytest.raises(MuAlreadyExistsError): + distribution.create_likelihood_variable("data", mu=mu, observed=10) + + +def test_create_likelihood_non_mu_parameterized_distribution() -> None: + distribution = Prior("Cauchy") + + with pm.Model(): + mu = pm.Normal("mu") + with 
pytest.raises(UnsupportedDistributionError): + distribution.create_likelihood_variable("data", mu=mu, observed=10) + + +def test_non_centered_student_t() -> None: + try: + Prior( + "StudentT", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + nu=Prior("HalfNormal"), + dims="channel", + centered=False, + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_cant_reset_distribution() -> None: + dist = Prior("Normal") + with pytest.raises(AttributeError, match="Can't change the distribution"): + dist.distribution = "Cauchy" + + +def test_nonstring_distribution() -> None: + with pytest.raises(ValidationError, match=".*Input should be a valid string.*"): + Prior(pm.Normal) + + +def test_change_the_transform() -> None: + dist = Prior("Normal") + dist.transform = "logit" + assert dist.transform == "logit" + + +def test_nonstring_transform() -> None: + with pytest.raises(ValidationError, match=".*Input should be a valid string.*"): + Prior("Normal", transform=pm.math.log) + + +def test_checks_param_value_types() -> None: + with pytest.raises(ValueError, match="Parameters must be one of the following types"): + Prior("Normal", mu="str", sigma="str") + + +def test_check_equality_with_numpy() -> None: + dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1) + assert dist == dist.deepcopy() + + +def clear_custom_transforms() -> None: + global CUSTOM_TRANSFORMS + CUSTOM_TRANSFORMS = {} + + +def test_custom_transform() -> None: + new_transform_name = "foo_bar" + with pytest.raises(UnknownTransformError): + Prior("Normal", transform=new_transform_name) + + register_tensor_transform(new_transform_name, lambda x: x**2) + + dist = Prior("Normal", transform=new_transform_name) + prior = dist.sample_prior(samples=10) + df_prior = prior.to_dataframe() + + np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2) + + +def test_custom_transform_comes_first() -> None: + # function in pytensor.tensor + register_tensor_transform("square", lambda x: 2 * x) + + dist = Prior("Normal", transform="square") + prior = dist.sample_prior(samples=10) + df_prior = prior.to_dataframe() + + np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()) + + clear_custom_transforms() + + +def test_serialize_with_pytensor() -> None: + sigma = pt.arange(1, 4) + dist = Prior("Normal", mu=0, sigma=sigma) + + assert dist.to_dict() == { + "dist": "Normal", + "kwargs": { + "mu": 0, + "sigma": [1, 2, 3], + }, + } + + +def test_zsn_non_centered() -> None: + try: + Prior("ZeroSumNormal", sigma=1, centered=False) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +class Arbitrary: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims + + def create_variable(self, name: str): + return pm.Normal(name, dims=self.dims) + + +class ArbitraryWithoutName: + def __init__(self, dims: str | tuple[str, ...]) -> None: + self.dims = dims + + def create_variable(self, name: str): + with pm.Model(name=name): + location = pm.Normal("location", dims=self.dims) + scale = pm.HalfNormal("scale", dims=self.dims) + + return pm.Normal("standard_normal") * scale + location + + +def test_sample_prior_arbitrary() -> None: + var = Arbitrary(dims="channel") + + prior = sample_prior(var, coords={"channel": ["A", "B", "C"]}, draws=25) + + assert isinstance(prior, xr.Dataset) + + +def test_sample_prior_arbitrary_no_name() -> None: + var = ArbitraryWithoutName(dims="channel") + + prior = sample_prior(var, coords={"channel": 
["A", "B", "C"]}, draws=25) + + assert isinstance(prior, xr.Dataset) + assert "var" not in prior + + prior_with = sample_prior( + var, + coords={"channel": ["A", "B", "C"]}, + draws=25, + wrap=True, + ) + + assert isinstance(prior_with, xr.Dataset) + assert "var" in prior_with + + +def test_create_prior_with_arbitrary() -> None: + dist = Prior( + "Normal", + mu=Arbitrary(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + + coords = { + "channel": ["C1", "C2", "C3"], + "geo": ["G1", "G2"], + } + with pm.Model(coords=coords) as model: + dist.create_variable("var") + + assert "var_mu" in model + var_mu = model["var_mu"] + + assert fast_eval(var_mu).shape == (len(coords["channel"]),) + + +def test_censored_is_variable_factory() -> None: + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + assert isinstance(censored_normal, VariableFactory) + + +@pytest.mark.parametrize( + "dims, expected_dims", + [ + ("channel", ("channel",)), + (("channel", "geo"), ("channel", "geo")), + ], + ids=["string", "tuple"], +) +def test_censored_dims_from_distribution(dims, expected_dims) -> None: + normal = Prior("Normal", dims=dims) + censored_normal = Censored(normal, lower=0) + + assert censored_normal.dims == expected_dims + + +def test_censored_variables_created() -> None: + normal = Prior("Normal", mu=Prior("Normal"), dims="dim") + censored_normal = Censored(normal, lower=0) + + coords = {"dim": range(3)} + with pm.Model(coords=coords) as model: + censored_normal.create_variable("var") + + var_names = ["var", "var_mu"] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [(3,), ()] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_censored_sample_prior() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": ["A", "B", "C"]} + prior = censored_normal.sample_prior(coords=coords, samples=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + +def test_censored_to_graph() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + G = censored_normal.to_graph() + assert isinstance(G, Digraph) + + +def test_censored_likelihood_variable() -> None: + normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu") + variable = censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=[1, 2, 3], + ) + + assert isinstance(variable, pt.TensorVariable) + assert model.observed_RVs == [variable] + assert "likelihood_sigma" in model + + +def test_censored_likelihood_unsupported_distribution() -> None: + cauchy = Prior("Cauchy") + censored_cauchy = Censored(cauchy, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(UnsupportedDistributionError): + censored_cauchy.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_likelihood_already_has_mu() -> None: + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) + censored_normal = Censored(normal, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(MuAlreadyExistsError): + censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_to_dict() -> 
None: + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + censored_normal = Censored(normal, lower=0) + + data = censored_normal.to_dict() + assert data == { + "class": "Censored", + "data": {"dist": normal.to_dict(), "lower": 0, "upper": float("inf")}, + } + + +def test_deserialize_censored() -> None: + data = { + "class": "Censored", + "data": { + "dist": { + "dist": "Normal", + }, + "lower": 0, + "upper": float("inf"), + }, + } + + instance = deserialize(data) + assert isinstance(instance, Censored) + assert isinstance(instance.distribution, Prior) + assert instance.lower == 0 + assert instance.upper == float("inf") + + +class ArbitrarySerializable(Arbitrary): + def to_dict(self): + return {"dims": self.dims} + + +@pytest.fixture +def arbitrary_serialized_data() -> dict: + return {"dims": ("channel",)} + + +def test_create_prior_with_arbitrary_serializable(arbitrary_serialized_data) -> None: + dist = Prior( + "Normal", + mu=ArbitrarySerializable(dims=("channel",)), + sigma=1, + dims=("channel", "geo"), + ) + + assert dist.to_dict() == { + "dist": "Normal", + "kwargs": { + "mu": arbitrary_serialized_data, + "sigma": 1, + }, + "dims": ("channel", "geo"), + } + + +@pytest.fixture +def register_arbitrary_deserialization(): + register_deserialization( + lambda data: isinstance(data, dict) and data.keys() == {"dims"}, + lambda data: ArbitrarySerializable(**data), + ) + + yield + + DESERIALIZERS.pop() + + +def test_deserialize_arbitrary_within_prior( + arbitrary_serialized_data, + register_arbitrary_deserialization, +) -> None: + data = { + "dist": "Normal", + "kwargs": { + "mu": arbitrary_serialized_data, + "sigma": 1, + }, + "dims": ("channel", "geo"), + } + + dist = deserialize(data) + assert isinstance(dist["mu"], ArbitrarySerializable) + assert dist["mu"].dims == ("channel",) + + +def test_censored_with_tensor_variable() -> None: + normal = Prior("Normal", dims="channel") + lower = pt.as_tensor_variable([0, 1, 2]) + censored_normal = Censored(normal, lower=lower) + + assert censored_normal.to_dict() == { + "class": "Censored", + "data": { + "dist": normal.to_dict(), + "lower": [0, 1, 2], + "upper": float("inf"), + }, + } + + +def test_censored_dims_setter() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + censored_normal.dims = "date" + assert normal.dims == ("date",) + + +class ModelData(NamedTuple): + mu: float + observed: list[float] + + +@pytest.fixture(scope="session") +def model_data() -> ModelData: + return ModelData(mu=0, observed=[0, 1, 2, 3, 4]) + + +@pytest.fixture(scope="session") +def normal_model_with_censored_API(model_data) -> pm.Model: + coords = {"idx": range(len(model_data.observed))} + with pm.Model(coords=coords) as model: + sigma = Prior("HalfNormal") + normal = Prior("Normal", sigma=sigma, dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + mu=model_data.mu, + observed=model_data.observed, + ) + + return model + + +@pytest.fixture(scope="session") +def normal_model_with_censored_logp(normal_model_with_censored_API): + return normal_model_with_censored_API.compile_logp() + + +@pytest.fixture(scope="session") +def expected_normal_model(model_data) -> pm.Model: + n_points = len(model_data.observed) + with pm.Model() as expected_model: + sigma = pm.HalfNormal("censored_normal_sigma") + normal = pm.Normal.dist(mu=model_data.mu, sigma=sigma, shape=n_points) + pm.Censored( + "censored_normal", + normal, + lower=0, + upper=np.inf, + observed=model_data.observed, + ) + + return 
expected_model + + +@pytest.fixture(scope="session") +def expected_normal_model_logp(expected_normal_model): + return expected_normal_model.compile_logp() + + +@pytest.mark.parametrize("sigma_log__", [-10, -5, -2.5, 0, 2.5, 5, 10]) +def test_censored_normal_logp( + sigma_log__, + normal_model_with_censored_logp, + expected_normal_model_logp, +) -> None: + points = {"censored_normal_sigma_log__": sigma_log__} + normal_model_logp = normal_model_with_censored_logp(points) + expected_model_logp = expected_normal_model_logp(points) + np.testing.assert_allclose(normal_model_logp, expected_model_logp) + + +@pytest.mark.parametrize( + "mu", + [ + 0, + np.arange(10), + ], + ids=["scalar", "vector"], +) +def test_censored_logp(mu) -> None: + n_points = 10 + observed = np.zeros(n_points) + coords = {"idx": range(n_points)} + with pm.Model(coords=coords) as model: + normal = Prior("Normal", dims="idx") + Censored(normal, lower=0).create_likelihood_variable( + "censored_normal", + observed=observed, + mu=mu, + ) + logp = model.compile_logp() + + with pm.Model() as expected_model: + pm.Censored( + "censored_normal", + pm.Normal.dist(mu=mu, sigma=1, shape=n_points), + lower=0, + upper=np.inf, + observed=observed, + ) + expected_logp = expected_model.compile_logp() + + point = {} + np.testing.assert_allclose(logp(point), expected_logp(point)) + + +def test_scaled_initializes_correctly() -> None: + """Test that the Scaled class initializes correctly.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) + + assert scaled.dist == normal + assert scaled.factor == 2.0 + + +def test_scaled_dims_property() -> None: + """Test that the dims property returns the dimensions of the underlying distribution.""" + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + scaled = Scaled(normal, factor=2.0) + + assert scaled.dims == ("channel",) + + # Test with multiple dimensions + normal.dims = ("channel", "geo") + assert scaled.dims == ("channel", "geo") + + +def test_scaled_create_variable() -> None: + """Test that the create_variable method properly scales the variable.""" + normal = Prior("Normal", mu=0, sigma=1) + scaled = Scaled(normal, factor=2.0) + + with pm.Model() as model: + scaled_var = scaled.create_variable("scaled_var") + + # Check that both the scaled and unscaled variables exist + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + + # The deterministic node should be the scaled variable + assert model["scaled_var"] == scaled_var + + +def test_scaled_creates_correct_dimensions() -> None: + """Test that the scaled variable has the correct dimensions.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords): + scaled_var = scaled.create_variable("scaled_var") + + # Check that the scaled variable has the correct dimensions + assert fast_eval(scaled_var).shape == (3,) + + +def test_scaled_applies_factor() -> None: + """Test that the scaling factor is correctly applied.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = 3.5 + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify scaling + prior = sample_prior(scaled, samples=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * factor) + + +def 
test_scaled_with_tensor_factor() -> None: + """Test that the Scaled class works with a tensor factor.""" + normal = Prior("Normal", mu=0, sigma=1) + factor = pt.as_tensor_variable(2.5) + scaled = Scaled(normal, factor=factor) + + # Sample from prior to verify tensor scaling + prior = sample_prior(scaled, samples=10, name="scaled_var") + df_prior = prior.to_dataframe() + + # Check that scaled values are original values times the factor + unscaled_values = df_prior["scaled_var_unscaled"].to_numpy() + scaled_values = df_prior["scaled_var"].to_numpy() + np.testing.assert_allclose(scaled_values, unscaled_values * 2.5) + + +def test_scaled_with_hierarchical_prior() -> None: + """Test that the Scaled class works with hierarchical priors.""" + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + with pm.Model(coords=coords) as model: + scaled.create_variable("scaled_var") + + # Check that all necessary variables were created + assert "scaled_var" in model + assert "scaled_var_unscaled" in model + assert "scaled_var_unscaled_mu" in model + assert "scaled_var_unscaled_sigma" in model + + +def test_scaled_sample_prior() -> None: + """Test that sample_prior works with the Scaled class.""" + normal = Prior("Normal", dims="channel") + scaled = Scaled(normal, factor=2.0) + + coords = {"channel": ["A", "B", "C"]} + prior = sample_prior(scaled, coords=coords, draws=25, name="scaled_var") + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + assert "scaled_var" in prior + assert "scaled_var_unscaled" in prior From 049d02bd47af39376eb57202dd127dc3718735eb Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 4 May 2025 09:45:27 -0400 Subject: [PATCH 03/15] remove references and bring back deserialize --- pymc_extras/deserialize.py | 5 ++--- pymc_extras/prior.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py index e73cc3049..8172cc231 100644 --- a/pymc_extras/deserialize.py +++ b/pymc_extras/deserialize.py @@ -1,12 +1,11 @@ -"""Deserialize into a PyMC-Marketing object. +"""Deserialize dictionaries into Python objects. This is a two step process: 1. Determine if the data is of the correct type. 2. Deserialize the data into a python object. -This is used to deserialize JSON data into PyMC-Marketing objects -throughout the package. +This is used to deserialize JSON data for PyMC-Marketing. 
Examples -------- diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 715df6874..53ff43e5c 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -82,6 +82,7 @@ def custom_transform(x): from __future__ import annotations import copy + from collections.abc import Callable from inspect import signature from typing import Any, Protocol, runtime_checkable @@ -90,6 +91,7 @@ def custom_transform(x): import pymc as pm import pytensor.tensor as pt import xarray as xr + from pydantic import InstanceOf, validate_call from pydantic.dataclasses import dataclass from pymc.distributions.shape_utils import Dims @@ -1346,3 +1348,15 @@ def create_variable(self, name: str) -> pt.TensorVariable: """ var = self.dist.create_variable(f"{name}_unscaled") return pm.Deterministic(name, var * self.factor, dims=self.dims) + + +def _is_prior_type(data: dict) -> bool: + return "dist" in data + + +def _is_censored_type(data: dict) -> bool: + return data.keys() == {"class", "data"} and data["class"] == "Censored" + + +register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) +register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) From ab02a57b6133c6c63edb36877da4522a5ac0c3f2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 4 May 2025 09:46:24 -0400 Subject: [PATCH 04/15] add to requirements --- conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-test.yml | 1 + requirements-dev.txt | 2 ++ requirements.txt | 1 + 4 files changed, 5 insertions(+) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 260c7b0e5..84e467ec8 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -10,6 +10,7 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 +- pydantic>=2.0.0 - pip - pip: - blackjax diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 6a92aea55..7b53c4342 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -11,6 +11,7 @@ dependencies: - statsmodels - numba<=0.60.0 - pymc>=5.21 +- pydantic>=2.0.0 - pip: - blackjax - scikit-learn diff --git a/requirements-dev.txt b/requirements-dev.txt index a28518d8e..8bbe56426 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,5 @@ blackjax # Used as benchmark for statespace models statsmodels +pydantic>=2.0.0 +preliz diff --git a/requirements.txt b/requirements.txt index 49c7d88af..9636089dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pymc>=5.21.1 scikit-learn better-optimize +pydantic>=2.0.0 From 344a78793c060d2baf62eba5fc8b8ecf05c26cab Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:17:23 -0400 Subject: [PATCH 05/15] add check for list --- pymc_extras/prior.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 53ff43e5c..792379ad4 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -443,6 +443,9 @@ def dims(self, dims) -> None: if isinstance(dims, str): dims = (dims,) + if isinstance(dims, list): + dims = tuple(dims) + self._dims = dims or () self._param_dims_work() From 6d6ff783434bd3cdc0cb0785188a3bf7f2b3ddb4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:28:06 -0400 Subject: [PATCH 06/15] add pydantic --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 37c4bc631..529f99488 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -37,7 +37,8 @@ dynamic = ["version"] # specify the version in the __init__.py file dependencies = [ "pymc>=5.21.1", "scikit-learn", - "better-optimize" + "better-optimize", + "pydantic>=2.0.0", ] [project.optional-dependencies] From 5b6badea2e672313622a78716e1dfd27a3df0e08 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:47:03 -0400 Subject: [PATCH 07/15] add pytest-mock --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 529f99488..2b8120b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ complete = [ ] dev = [ "pytest>=6.0", + "pytest-mock", "dask[all]<2025.1.1", "blackjax", "statsmodels", From ac19b3a4c74375b2ebf52a9268eab571ba3fde18 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:47:17 -0400 Subject: [PATCH 08/15] fix tests --- tests/test_prior.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_prior.py b/tests/test_prior.py index d0a3490ea..70729b9f9 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -8,6 +8,7 @@ import xarray as xr from graphviz.graphs import Digraph +from preliz.distributions import distributions as preliz_distributions from pydantic import ValidationError from pymc.model_graph import fast_eval @@ -31,8 +32,6 @@ sample_prior, ) -pz = pytest.importorskip("preliz") - @pytest.mark.parametrize( "x, dims, desired_dims, expected_fn", @@ -133,7 +132,7 @@ def test_doesnt_check_validity_values() -> None: def test_preliz() -> None: var = Prior("Normal", mu=0, sigma=1) dist = var.preliz - assert isinstance(dist, pz.distributions.Distribution) + assert isinstance(dist, preliz_distributions.Distribution) @pytest.mark.parametrize( @@ -442,7 +441,7 @@ def test_sample_prior() -> None: ) coords = {"channel": ["A", "B", "C"]} - prior = var.sample_prior(coords=coords, samples=25) + prior = var.sample_prior(coords=coords, draws=25) assert isinstance(prior, xr.Dataset) assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} @@ -614,7 +613,7 @@ def test_custom_transform() -> None: register_tensor_transform(new_transform_name, lambda x: x**2) dist = Prior("Normal", transform=new_transform_name) - prior = dist.sample_prior(samples=10) + prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() np.testing.assert_array_equal(df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2) @@ -625,7 +624,7 @@ def test_custom_transform_comes_first() -> None: register_tensor_transform("square", lambda x: 2 * x) dist = Prior("Normal", transform="square") - prior = dist.sample_prior(samples=10) + prior = dist.sample_prior(draws=10) df_prior = prior.to_dataframe() np.testing.assert_array_equal(df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()) @@ -763,7 +762,7 @@ def test_censored_sample_prior() -> None: censored_normal = Censored(normal, lower=0) coords = {"channel": ["A", "B", "C"]} - prior = censored_normal.sample_prior(coords=coords, samples=25) + prior = censored_normal.sample_prior(coords=coords, draws=25) assert isinstance(prior, xr.Dataset) assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} @@ -1089,7 +1088,7 @@ def test_scaled_applies_factor() -> None: scaled = Scaled(normal, factor=factor) # Sample from prior to verify scaling - prior = sample_prior(scaled, samples=10, name="scaled_var") + prior = sample_prior(scaled, draws=10, name="scaled_var") df_prior = prior.to_dataframe() # Check that scaled values are original values times the factor @@ -1105,7 +1104,7 @@ 
def test_scaled_with_tensor_factor() -> None: scaled = Scaled(normal, factor=factor) # Sample from prior to verify tensor scaling - prior = sample_prior(scaled, samples=10, name="scaled_var") + prior = sample_prior(scaled, draws=10, name="scaled_var") df_prior = prior.to_dataframe() # Check that scaled values are original values times the factor From 6c82064bd731cc8d1c5e218b05400efe1a570787 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:47:45 -0400 Subject: [PATCH 09/15] add preliz --- conda-envs/environment-test.yml | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 27ac5ae13..f1abbec89 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -14,6 +14,7 @@ dependencies: - pytest-cov - libgcc<15 - pydantic>=2.0.0 + - preliz - pip - pip: - jax diff --git a/pyproject.toml b/pyproject.toml index 2b8120b6b..646cc0fd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "scikit-learn", "better-optimize", "pydantic>=2.0.0", + "preliz", ] [project.optional-dependencies] From 25ace5aedabad5bea2325454cf2eb57de7cc6700 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 09:53:39 -0400 Subject: [PATCH 10/15] remove references to pymc-marketig --- pymc_extras/deserialize.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py index 8172cc231..ac58848e5 100644 --- a/pymc_extras/deserialize.py +++ b/pymc_extras/deserialize.py @@ -5,11 +5,9 @@ 1. Determine if the data is of the correct type. 2. Deserialize the data into a python object. -This is used to deserialize JSON data for PyMC-Marketing. - Examples -------- -Make use of the already registered PyMC-Marketing deserializers: +Make use of the already registered deserializers: .. code-block:: python @@ -183,10 +181,6 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None: Use the :func:`deserialize` function to then deserialize data using all registered deserialize functions. - Classes from PyMC-Marketing have their deserialization mappings registered - automatically. However, custom classes will need to be registered manually - using this function before they can be deserialized. - Parameters ---------- is_type : Callable[[Any], bool] From ebb0b61e97daea54cb6122d603f1ac8c37274b89 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 17:34:04 -0400 Subject: [PATCH 11/15] install test for tests --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 65d63e40c..d6a776df2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,7 +53,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e . + pip install -e ".[test]" python --version - name: Run tests run: | @@ -97,7 +97,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e . + pip install -e ".[test]" python --version - name: Run tests # This job uses a cmd shell, therefore the environment variable syntax is different! 
From 60add08ffb6b8b5f50217e2027ea7ed8b2e7c994 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 17:41:47 -0400 Subject: [PATCH 12/15] add to the docs --- docs/api_reference.rst | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 5d4f62bb2..fd995a853 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -46,6 +46,31 @@ Distributions Skellam histogram_approximation +Prior +===== + +.. currentmodule:: pymc_extras.prior +.. autosummary:: + :toctree: generated/ + + create_dim_handler + handle_dims + Prior + VariableFactory + sample_prior + Censored + Scaled + +Deserialize +=========== + +.. currentmodule:: pymc_extras.deserialize +.. autosummary:: + :toctree: generated/ + + deserialize + register_deserialization + Deserializer Transforms ========== From 3839bff6e30c019330ac776d41d506654687d2ce Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 20 May 2025 19:03:07 -0400 Subject: [PATCH 13/15] fix the typo --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d6a776df2..6ff852ebc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,7 +53,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e ".[test]" + pip install -e ".[dev]" python --version - name: Run tests run: | @@ -97,7 +97,7 @@ jobs: cache-environment: true - name: Install pymc-extras run: | - pip install -e ".[test]" + pip install -e ".[dev]" python --version - name: Run tests # This job uses a cmd shell, therefore the environment variable syntax is different! From b9868bb69e920f588ab84137981282a43a708ed7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 27 May 2025 12:05:02 -0400 Subject: [PATCH 14/15] remove deserialize module --- pymc_extras/deserialize.py | 224 ------------------------------------- pymc_extras/prior.py | 179 ----------------------------- tests/test_deserialize.py | 59 ---------- tests/test_prior.py | 222 ------------------------------------ 4 files changed, 684 deletions(-) delete mode 100644 pymc_extras/deserialize.py delete mode 100644 tests/test_deserialize.py diff --git a/pymc_extras/deserialize.py b/pymc_extras/deserialize.py deleted file mode 100644 index ac58848e5..000000000 --- a/pymc_extras/deserialize.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Deserialize dictionaries into Python objects. - -This is a two step process: - -1. Determine if the data is of the correct type. -2. Deserialize the data into a python object. - -Examples --------- -Make use of the already registered deserializers: - -.. code-block:: python - - from pymc_extras.deserialize import deserialize - - prior_class_data = { - "dist": "Normal", - "kwargs": {"mu": 0, "sigma": 1} - } - prior = deserialize(prior_class_data) - # Prior("Normal", mu=0, sigma=1) - -Register custom class deserialization: - -.. code-block:: python - - from pymc_extras.deserialize import register_deserialization - - class MyClass: - def __init__(self, value: int): - self.value = value - - def to_dict(self) -> dict: - # Example of what the to_dict method might look like. - return {"value": self.value} - - register_deserialization( - is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int), - deserialize=lambda data: MyClass(value=data["value"]), - ) - -Deserialize data into that custom class: - -.. 
code-block:: python - - from pymc_extras.deserialize import deserialize - - data = {"value": 42} - obj = deserialize(data) - assert isinstance(obj, MyClass) - - -""" - -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -IsType = Callable[[Any], bool] -Deserialize = Callable[[Any], Any] - - -@dataclass -class Deserializer: - """Object to store information required for deserialization. - - All deserializers should be stored via the :func:`register_deserialization` function - instead of creating this object directly. - - Attributes - ---------- - is_type : IsType - Function to determine if the data is of the correct type. - deserialize : Deserialize - Function to deserialize the data. - - Examples - -------- - .. code-block:: python - - from typing import Any - - class MyClass: - def __init__(self, value: int): - self.value = value - - from pymc_extras.deserialize import Deserializer - - def is_type(data: Any) -> bool: - return data.keys() == {"value"} and isinstance(data["value"], int) - - def deserialize(data: dict) -> MyClass: - return MyClass(value=data["value"]) - - deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize) - - """ - - is_type: IsType - deserialize: Deserialize - - -DESERIALIZERS: list[Deserializer] = [] - - -class DeserializableError(Exception): - """Error raised when data cannot be deserialized.""" - - def __init__(self, data: Any): - self.data = data - super().__init__( - f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping." - ) - - -def deserialize(data: Any) -> Any: - """Deserialize a dictionary into a Python object. - - Use the :func:`register_deserialization` function to add custom deserializations. - - Deserialization is a two step process due to the dynamic nature of the data: - - 1. Determine if the data is of the correct type. - 2. Deserialize the data into a Python object. - - Each registered deserialization is checked in order until one is found that can - deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised. - - A :class:`DeserializableError` is raised when the data fails to be deserialized - by any of the registered deserializers. - - Parameters - ---------- - data : Any - The data to deserialize. - - Returns - ------- - Any - The deserialized object. - - Raises - ------ - DeserializableError - Raised when the data doesn't match any registered deserializations - or fails to be deserialized. - - Examples - -------- - Deserialize a :class:`pymc_extras.prior.Prior` object: - - .. code-block:: python - - from pymc_extras.deserialize import deserialize - - data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} - prior = deserialize(data) - # Prior("Normal", mu=0, sigma=1) - - """ - for mapping in DESERIALIZERS: - try: - is_type = mapping.is_type(data) - except Exception: - is_type = False - - if not is_type: - continue - - try: - return mapping.deserialize(data) - except Exception as e: - raise DeserializableError(data) from e - else: - raise DeserializableError(data) - - -def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None: - """Register an arbitrary deserialization. - - Use the :func:`deserialize` function to then deserialize data using all registered - deserialize functions. - - Parameters - ---------- - is_type : Callable[[Any], bool] - Function to determine if the data is of the correct type. - deserialize : Callable[[dict], Any] - Function to deserialize the data of that type. 
- - Examples - -------- - Register a custom class deserialization: - - .. code-block:: python - - from pymc_extras.deserialize import register_deserialization - - class MyClass: - def __init__(self, value: int): - self.value = value - - def to_dict(self) -> dict: - # Example of what the to_dict method might look like. - return {"value": self.value} - - register_deserialization( - is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int), - deserialize=lambda data: MyClass(value=data["value"]), - ) - - Use that custom class deserialization: - - .. code-block:: python - - from pymc_extras.deserialize import deserialize - - data = {"value": 42} - obj = deserialize(data) - assert isinstance(obj, MyClass) - - """ - mapping = Deserializer(is_type=is_type, deserialize=deserialize) - DESERIALIZERS.append(mapping) diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 792379ad4..20702bfc9 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -96,8 +96,6 @@ def custom_transform(x): from pydantic.dataclasses import dataclass from pymc.distributions.shape_utils import Dims -from pymc_extras.deserialize import deserialize, register_deserialization - class UnsupportedShapeError(Exception): """Error for when the shapes from variables are not compatible.""" @@ -687,143 +685,6 @@ def preliz(self): return getattr(pz, self.distribution)(**self.parameters) - def to_dict(self) -> dict[str, Any]: - """Convert the prior to dictionary format. - - Returns - ------- - dict[str, Any] - The dictionary format of the prior. - - Examples - -------- - Convert a prior to the dictionary format. - - .. code-block:: python - - from pymc_extras.prior import Prior - - dist = Prior("Normal", mu=0, sigma=1) - - dist.to_dict() - # {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} - - Convert a hierarchical prior to the dictionary format. - - .. code-block:: python - - dist = Prior( - "Normal", - mu=Prior("Normal"), - sigma=Prior("HalfNormal"), - dims="channel", - ) - - dist.to_dict() - # { - # "dist": "Normal", - # "kwargs": { - # "mu": {"dist": "Normal"}, - # "sigma": {"dist": "HalfNormal"}, - # }, - # "dims": "channel", - # } - - """ - data: dict[str, Any] = { - "dist": self.distribution, - } - if self.parameters: - - def handle_value(value): - if isinstance(value, Prior): - return value.to_dict() - - if isinstance(value, pt.TensorVariable): - value = value.eval() - - if isinstance(value, np.ndarray): - return value.tolist() - - if hasattr(value, "to_dict"): - return value.to_dict() - - return value - - data["kwargs"] = { - param: handle_value(value) for param, value in self.parameters.items() - } - if not self.centered: - data["centered"] = False - - if self.dims: - data["dims"] = self.dims - - if self.transform: - data["transform"] = self.transform - - return data - - @classmethod - def from_dict(cls, data) -> Prior: - """Create a Prior from the dictionary format. - - Parameters - ---------- - data : dict[str, Any] - The dictionary format of the prior. - - Returns - ------- - Prior - The prior distribution. - - Examples - -------- - Convert prior in the dictionary format to a Prior instance. - - .. code-block:: python - - from pymc_extras.prior import Prior - - data = { - "dist": "Normal", - "kwargs": {"mu": 0, "sigma": 1}, - } - - dist = Prior.from_dict(data) - dist - # Prior("Normal", mu=0, sigma=1) - - """ - if not isinstance(data, dict): - msg = ( - "Must be a dictionary representation of a prior distribution. 
" - f"Not of type: {type(data)}" - ) - raise ValueError(msg) - - dist = data["dist"] - kwargs = data.get("kwargs", {}) - - def handle_value(value): - if isinstance(value, dict): - return deserialize(value) - - if isinstance(value, list): - return np.array(value) - - return value - - kwargs = {param: handle_value(value) for param, value in kwargs.items()} - centered = data.get("centered", True) - dims = data.get("dims") - if isinstance(dims, list): - dims = tuple(dims) - transform = data.get("transform") - - return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs) - def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior: """Create a new prior with a given mass constrained within the given bounds. @@ -1161,34 +1022,6 @@ def create_variable(self, name: str) -> pt.TensorVariable: dims=self.dims, ) - def to_dict(self) -> dict[str, Any]: - """Convert the censored distribution to a dictionary.""" - - def handle_value(value): - if isinstance(value, pt.TensorVariable): - return value.eval().tolist() - - return value - - return { - "class": "Censored", - "data": { - "dist": self.distribution.to_dict(), - "lower": handle_value(self.lower), - "upper": handle_value(self.upper), - }, - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Censored: - """Create a censored distribution from a dictionary.""" - data = data["data"] - return cls( # type: ignore - distribution=Prior.from_dict(data["dist"]), - lower=data["lower"], - upper=data["upper"], - ) - def sample_prior( self, coords=None, @@ -1351,15 +1184,3 @@ def create_variable(self, name: str) -> pt.TensorVariable: """ var = self.dist.create_variable(f"{name}_unscaled") return pm.Deterministic(name, var * self.factor, dims=self.dims) - - -def _is_prior_type(data: dict) -> bool: - return "dist" in data - - -def _is_censored_type(data: dict) -> bool: - return data.keys() == {"class", "data"} and data["class"] == "Censored" - - -register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) -register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py deleted file mode 100644 index 2fcc28a13..000000000 --- a/tests/test_deserialize.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest - -from pymc_extras.deserialize import ( - DESERIALIZERS, - DeserializableError, - deserialize, - register_deserialization, -) - - -@pytest.mark.parametrize( - "unknown_data", - [ - {"unknown": 1}, - {"dist": "Normal", "kwargs": {"something": "else"}}, - 1, - ], - ids=["unknown_structure", "prior_like", "non_dict"], -) -def test_unknown_type_raises(unknown_data) -> None: - match = "Couldn't deserialize" - with pytest.raises(DeserializableError, match=match): - deserialize(unknown_data) - - -class ArbitraryObject: - def __init__(self, code: str): - self.code = code - self.value = 1 - - -@pytest.fixture -def register_arbitrary_object(): - register_deserialization( - is_type=lambda data: data.keys() == {"code"}, - deserialize=lambda data: ArbitraryObject(code=data["code"]), - ) - - yield - - DESERIALIZERS.pop() - - -def test_registration(register_arbitrary_object) -> None: - instance = deserialize({"code": "test"}) - - assert isinstance(instance, ArbitraryObject) - assert instance.code == "test" - - -def test_registeration_mixup() -> None: - data_that_looks_like_prior = { - "dist": "Normal", - "kwargs": {"something": "else"}, - } - - match = "Couldn't deserialize" - with pytest.raises(DeserializableError, 
match=match): - deserialize(data_that_looks_like_prior) diff --git a/tests/test_prior.py b/tests/test_prior.py index 70729b9f9..c534da021 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -12,11 +12,6 @@ from pydantic import ValidationError from pymc.model_graph import fast_eval -from pymc_extras.deserialize import ( - DESERIALIZERS, - deserialize, - register_deserialization, -) from pymc_extras.prior import ( Censored, MuAlreadyExistsError, @@ -301,57 +296,6 @@ def test_transform() -> None: assert fast_eval(model[var_name]).shape == dim -def test_to_dict(large_var) -> None: - data = large_var.to_dict() - - assert data == { - "dist": "Normal", - "kwargs": { - "mu": { - "dist": "Normal", - "kwargs": { - "mu": { - "dist": "Normal", - "kwargs": { - "mu": 1, - }, - }, - "sigma": { - "dist": "HalfNormal", - }, - }, - "centered": False, - "dims": ("channel",), - }, - "sigma": { - "dist": "HalfNormal", - "kwargs": { - "sigma": { - "dist": "HalfNormal", - }, - }, - "dims": ("geo",), - }, - }, - "dims": ("geo", "channel"), - } - - -def test_to_dict_numpy() -> None: - var = Prior("Normal", mu=np.array([0, 10, 20]), dims="channel") - assert var.to_dict() == { - "dist": "Normal", - "kwargs": { - "mu": [0, 10, 20], - }, - "dims": ("channel",), - } - - -def test_dict_round_trip(large_var) -> None: - assert Prior.from_dict(large_var.to_dict()) == large_var - - def test_constrain_with_transform_error() -> None: var = Prior("Normal", transform="sigmoid") @@ -421,16 +365,6 @@ def mmm_default_model_config(): } -def test_backwards_compat(mmm_default_model_config) -> None: - result = {param: Prior.from_dict(value) for param, value in mmm_default_model_config.items()} - assert result == { - "intercept": Prior("Normal", mu=0, sigma=2), - "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)), - "gamma_control": Prior("Normal", mu=0, sigma=2, dims="control"), - "gamma_fourier": Prior("Laplace", mu=0, b=1, dims="fourier_mode"), - } - - def test_sample_prior() -> None: var = Prior( "Normal", @@ -466,46 +400,6 @@ def test_to_graph() -> None: assert isinstance(G, Digraph) -def test_from_dict_list() -> None: - data = { - "dist": "Normal", - "kwargs": { - "mu": [0, 1, 2], - "sigma": 1, - }, - "dims": "channel", - } - - var = Prior.from_dict(data) - assert var.dims == ("channel",) - assert isinstance(var["mu"], np.ndarray) - np.testing.assert_array_equal(var["mu"], [0, 1, 2]) - - -def test_from_dict_list_dims() -> None: - data = { - "dist": "Normal", - "kwargs": { - "mu": 0, - "sigma": 1, - }, - "dims": ["channel", "geo"], - } - - var = Prior.from_dict(data) - assert var.dims == ("channel", "geo") - - -def test_to_dict_transform() -> None: - dist = Prior("Normal", transform="sigmoid") - - data = dist.to_dict() - assert data == { - "dist": "Normal", - "transform": "sigmoid", - } - - def test_equality_non_prior() -> None: dist = Prior("Normal") @@ -632,19 +526,6 @@ def test_custom_transform_comes_first() -> None: clear_custom_transforms() -def test_serialize_with_pytensor() -> None: - sigma = pt.arange(1, 4) - dist = Prior("Normal", mu=0, sigma=sigma) - - assert dist.to_dict() == { - "dist": "Normal", - "kwargs": { - "mu": 0, - "sigma": [1, 2, 3], - }, - } - - def test_zsn_non_centered() -> None: try: Prior("ZeroSumNormal", sigma=1, centered=False) @@ -822,109 +703,6 @@ def test_censored_likelihood_already_has_mu() -> None: ) -def test_censored_to_dict() -> None: - normal = Prior("Normal", mu=0, sigma=1, dims="channel") - censored_normal = Censored(normal, lower=0) - - data = 
censored_normal.to_dict() - assert data == { - "class": "Censored", - "data": {"dist": normal.to_dict(), "lower": 0, "upper": float("inf")}, - } - - -def test_deserialize_censored() -> None: - data = { - "class": "Censored", - "data": { - "dist": { - "dist": "Normal", - }, - "lower": 0, - "upper": float("inf"), - }, - } - - instance = deserialize(data) - assert isinstance(instance, Censored) - assert isinstance(instance.distribution, Prior) - assert instance.lower == 0 - assert instance.upper == float("inf") - - -class ArbitrarySerializable(Arbitrary): - def to_dict(self): - return {"dims": self.dims} - - -@pytest.fixture -def arbitrary_serialized_data() -> dict: - return {"dims": ("channel",)} - - -def test_create_prior_with_arbitrary_serializable(arbitrary_serialized_data) -> None: - dist = Prior( - "Normal", - mu=ArbitrarySerializable(dims=("channel",)), - sigma=1, - dims=("channel", "geo"), - ) - - assert dist.to_dict() == { - "dist": "Normal", - "kwargs": { - "mu": arbitrary_serialized_data, - "sigma": 1, - }, - "dims": ("channel", "geo"), - } - - -@pytest.fixture -def register_arbitrary_deserialization(): - register_deserialization( - lambda data: isinstance(data, dict) and data.keys() == {"dims"}, - lambda data: ArbitrarySerializable(**data), - ) - - yield - - DESERIALIZERS.pop() - - -def test_deserialize_arbitrary_within_prior( - arbitrary_serialized_data, - register_arbitrary_deserialization, -) -> None: - data = { - "dist": "Normal", - "kwargs": { - "mu": arbitrary_serialized_data, - "sigma": 1, - }, - "dims": ("channel", "geo"), - } - - dist = deserialize(data) - assert isinstance(dist["mu"], ArbitrarySerializable) - assert dist["mu"].dims == ("channel",) - - -def test_censored_with_tensor_variable() -> None: - normal = Prior("Normal", dims="channel") - lower = pt.as_tensor_variable([0, 1, 2]) - censored_normal = Censored(normal, lower=lower) - - assert censored_normal.to_dict() == { - "class": "Censored", - "data": { - "dist": normal.to_dict(), - "lower": [0, 1, 2], - "upper": float("inf"), - }, - } - - def test_censored_dims_setter() -> None: normal = Prior("Normal", dims="channel") censored_normal = Censored(normal, lower=0) From bf0fe89d0431edf397e5132ab3c286f9e92d217f Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 29 May 2025 09:33:08 -0400 Subject: [PATCH 15/15] remove the deserialize docs --- docs/api_reference.rst | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index fd995a853..fbe157ce0 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -61,16 +61,6 @@ Prior Censored Scaled -Deserialize -=========== - -.. currentmodule:: pymc_extras.deserialize -.. autosummary:: - :toctree: generated/ - - deserialize - register_deserialization - Deserializer Transforms ==========