Implement limit case of R2D2M2CP for P #182

Merged (19 commits) on Jun 26, 2023
186 changes: 156 additions & 30 deletions pymc_experimental/distributions/multivariate/r2d2m2cp.py
@@ -15,18 +15,25 @@

from typing import Sequence, Union

import numpy as np
import pymc as pm
import pytensor.tensor as pt

__all__ = ["R2D2M2CP"]


def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable):
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
pi = pt.erfinv(2 * psi - 1)
f = (1 / (2 * pi**2 + 1)) ** 0.5
sigma = explained_var**0.5 * f
mu = sigma * pi * 2**0.5
return mu, sigma
if psi_mask is not None:
return (
pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5),
pt.where(psi_mask, sigma, 0),
)
else:
return mu, sigma
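
For context, _psivar2musigma maps a sign probability psi and an explained variance to the mean and scale of the coefficient prior; as psi approaches 0 or 1, erfinv diverges, the scale collapses to 0, and the mean tends to ±sqrt(explained_var), which is exactly what the psi_mask branch encodes in closed form. The standalone sketch below is not part of the changed file; it re-implements the map with scipy purely as a numeric check of that limit.

# Standalone check of the psi -> (mu, sigma) map from _psivar2musigma,
# written with scipy instead of pytensor for illustration only.
from scipy.special import erfinv

def psivar2musigma(psi, explained_var):
    pi = erfinv(2 * psi - 1)
    f = (1 / (2 * pi**2 + 1)) ** 0.5
    sigma = explained_var**0.5 * f
    mu = sigma * pi * 2**0.5
    return mu, sigma

for p in (0.9, 0.999, 0.9999999):
    mu, sigma = psivar2musigma(p, explained_var=1.0)
    print(f"psi={p}: mu={mu:.4f}, sigma={sigma:.4f}")
# mu tends to +sqrt(explained_var) and sigma tends to 0,
# matching the deterministic values used when psi_mask is False
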


def _R2D2M2CP_beta(
@@ -37,6 +44,7 @@ def _R2D2M2CP_beta(
phi: pt.TensorVariable,
psi: pt.TensorVariable,
*,
psi_mask,
dims: Union[str, Sequence[str]],
centered=False,
):
@@ -59,16 +67,141 @@
"""
tau2 = r2 / (1 - r2)
explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1)
mu_param, std_param = _psivar2musigma(psi, explained_variance)
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
if not centered:
with pm.Model(name):
raw = pm.Normal("raw", dims=dims)
if psi_mask is not None and psi_mask.any():
# mixed case: some probs are exactly 0 or 1 and the rest are free,
# so set_subtensor is required to scatter the free draws
r_idx = psi_mask.nonzero()
with pm.Model("raw"):
raw = pm.Normal("masked", shape=len(r_idx[0]))
raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw)
raw = pm.Deterministic("raw", raw, dims=dims)
elif psi_mask is not None:
# all variables are deterministic
raw = pt.zeros_like(mu_param)
else:
raw = pm.Normal("raw", dims=dims)
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
else:
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
if psi_mask is not None and psi_mask.any():
# mixed case: some probs are exactly 0 or 1 and the rest are free,
# so set_subtensor is required to scatter the sampled entries
r_idx = psi_mask.nonzero()
with pm.Model(name):
mean = (mu_param / input_sigma)[r_idx]
sigma = (std_param / input_sigma)[r_idx]
masked = pm.Normal(
"masked",
mean,
sigma,
shape=len(r_idx[0]),
)
beta = pt.set_subtensor(mean, masked)
beta = pm.Deterministic(name, beta, dims=dims)
elif psi_mask is not None:
# all variables are deterministic
beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims)
else:
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
return beta
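
In both parametrizations above, free Normal draws exist only at the positions where psi_mask is True; pt.set_subtensor scatters them into the full coefficient vector, whose remaining entries stay deterministic. The standalone sketch below is not part of the changed file and only isolates that scatter step.

# Scatter pattern used by the masked branches: sample only where the mask
# is True and fill the rest of the vector deterministically (here with zeros).
import numpy as np
import pytensor.tensor as pt

mask = np.array([True, False, True])      # True where psi is neither 0 nor 1
r_idx = mask.nonzero()
raw_masked = pt.vector("raw_masked")      # free draws, one per masked position
raw_full = pt.set_subtensor(pt.zeros(3)[r_idx], raw_masked)
values = np.asarray([1.5, -2.0], dtype=raw_masked.dtype)
print(raw_full.eval({raw_masked: values}))  # [ 1.5  0.  -2. ]
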


def _broadcast_as_dims(*values, dims):
model = pm.modelcontext(None)
shape = [len(model.coords[d]) for d in dims]
ret = tuple(np.broadcast_to(v, shape) for v in values)
# strip output
if len(values) == 1:
ret = ret[0]
return ret
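
_broadcast_as_dims resolves the target shape from the model coordinates and broadcasts each value to it. A minimal illustration with a hypothetical coordinate, not taken from the changed file:

# Broadcasting a scalar to the shape implied by a model coordinate.
import numpy as np
import pymc as pm

with pm.Model(coords={"variables": ["a", "b", "c"]}) as model:
    shape = [len(model.coords[d]) for d in ("variables",)]
    print(np.broadcast_to(0.5, shape))  # [0.5 0.5 0.5]
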


def _psi_masked(positive_probs, positive_probs_std, *, dims):
if not (
isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
):
raise TypeError(
"Only constant values for positive_probs and positive_probs_std are accepted"
)
positive_probs, positive_probs_std = _broadcast_as_dims(
positive_probs.data, positive_probs_std.data, dims=dims
)
mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0)
if np.bitwise_and(~mask, positive_probs_std != 0).any():
raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
if (~mask).any() and mask.any():
# mixed case: some probs are exactly 0 or 1 and the rest are free,
# so set_subtensor is required to combine them
r_idx = mask.nonzero()
with pm.Model("psi"):
psi = pm.Beta(
"masked",
mu=positive_probs[r_idx],
sigma=positive_probs_std[r_idx],
shape=len(r_idx[0]),
)
psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi)
psi = pm.Deterministic("psi", psi, dims=dims)
elif (~mask).all():
# limit case: every prob is exactly 0 or 1, so psi is fully deterministic
psi = pt.as_tensor(positive_probs)
else:
psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
mask = None
return mask, psi


def _psi(positive_probs, positive_probs_std, *, dims):
if positive_probs_std is not None:
mask, psi = _psi_masked(
positive_probs=pt.as_tensor(positive_probs),
positive_probs_std=pt.as_tensor(positive_probs_std),
dims=dims,
)
else:
positive_probs = pt.as_tensor(positive_probs)
if not isinstance(positive_probs, pt.Constant):
raise TypeError("Only constant values for positive_probs are allowed")
psi = _broadcast_as_dims(positive_probs.data, dims=dims)
mask = np.atleast_1d(~np.bitwise_or(psi == 1, psi == 0))
if mask.all():
mask = None
return mask, psi
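
_psi therefore returns psi together with a (possibly None) mask of the entries whose sign is genuinely uncertain; probabilities equal to exactly 0 or 1 fall outside the mask and are treated as the deterministic limit case. A rough numpy illustration of the mask logic, with made-up values:

# Entries equal to exactly 0 or 1 are excluded from the mask.
import numpy as np

positive_probs = np.array([0.5, 1.0, 0.9, 0.0])
mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0)
print(mask)  # [ True False  True False] -> free psi only where True
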


def _phi(
variables_importance,
variance_explained,
importance_concentration,
*,
dims,
):
*broadcast_dims, dim = dims
model = pm.modelcontext(None)
if variables_importance is not None:
if variance_explained is not None:
raise TypeError("Can't use variable importance with variance explained")
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variable importance with less than two variables")
variables_importance = pt.as_tensor(variables_importance)
if importance_concentration is not None:
variables_importance *= importance_concentration
return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim])
elif variance_explained is not None:
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variance explained with less than two variables")
phi = pt.as_tensor(variance_explained)
else:
phi = 1 / len(model.coords[dim])
phi = _broadcast_as_dims(phi, dims=dims)
if importance_concentration is not None:
return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
else:
return phi
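
importance_concentration acts as a pseudo-count: it scales either the raw importance vector or the (already normalized) point estimate of variance explained into the Dirichlet concentration used for phi. The numbers below are hypothetical and not taken from the diff; they only illustrate the scaling.

# Turning a point estimate of variance explained into a Dirichlet concentration.
import numpy as np

variance_explained = np.array([0.2, 0.3, 0.5])   # hypothetical point estimate
importance_concentration = 10.0
alpha = importance_concentration * variance_explained
print(alpha)  # [2. 3. 5.] -> Dirichlet(alpha) keeps phi close to the estimate
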


def R2D2M2CP(
name,
output_sigma,
@@ -78,6 +211,7 @@
r2,
variables_importance=None,
variance_explained=None,
importance_concentration=None,
r2_std=None,
positive_probs=0.5,
positive_probs_std=None,
@@ -102,6 +236,8 @@
variance_explained : tensor, optional
Alternative estimate for variables importance which is point estimate of
variance explained, should sum up to one, by default None
importance_concentration : tensor, optional
Confidence around variance explained or variable importance estimate
r2_std : tensor, optional
Optional uncertainty over :math:`R^2`, by default None
positive_probs : tensor, optional
@@ -125,8 +261,8 @@
-----
The R2D2M2CP prior is a modification of R2D2M2 prior.

- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
- ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
- R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)

Examples
--------
@@ -259,31 +395,20 @@ def R2D2M2CP(
input_sigma = pt.as_tensor(input_sigma)
output_sigma = pt.as_tensor(output_sigma)
with pm.Model(name) as model:
if variables_importance is not None:
if variance_explained is not None:
raise TypeError("Can't use variable importance with variance explained")
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variable importance with less than two variables")
phi = pm.Dirichlet(
"phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]
)
elif variance_explained is not None:
if len(model.coords[dim]) <= 1:
raise TypeError("Can't use variance explained with less than two variables")
phi = pt.as_tensor(variance_explained)
else:
phi = 1 / len(model.coords[dim])
if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims):
raise ValueError(f"{dims!r} should be constant length immutable dims")
if r2_std is not None:
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
if positive_probs_std is not None:
psi = pm.Beta(
"psi",
mu=pt.as_tensor(positive_probs),
sigma=pt.as_tensor(positive_probs_std),
dims=broadcast_dims + [dim],
)
else:
psi = pt.as_tensor(positive_probs)
phi = _phi(
variables_importance=variables_importance,
variance_explained=variance_explained,
importance_concentration=importance_concentration,
dims=dims,
)
mask, psi = _psi(
positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims
)

beta = _R2D2M2CP_beta(
name,
output_sigma,
Expand All @@ -293,6 +418,7 @@ def R2D2M2CP(
psi,
dims=broadcast_dims + [dim],
centered=centered,
psi_mask=mask,
)
resid_sigma = (1 - r2) ** 0.5 * output_sigma
return resid_sigma, beta
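
End to end, the new limit case is triggered by passing sign probabilities of exactly 0 or 1 (with the matching positive_probs_std entries set to 0). The sketch below is only illustrative: the data, coordinate name, and numeric choices are invented, and it assumes the usual pymc_experimental import alias pmx.

# Hedged usage sketch of the limit case: entries of positive_probs equal to
# exactly 1 or 0 make the corresponding coefficient signs deterministic.
import numpy as np
import pymc as pm
import pymc_experimental as pmx

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
coords = {"variables": ["a", "b", "c"]}
with pm.Model(coords=coords):
    resid_sigma, beta = pmx.distributions.R2D2M2CP(
        "beta",
        output_sigma=1.0,
        input_sigma=X.std(0),
        dims=("variables",),
        r2=0.8,
        # "a" is forced positive, "c" forced negative, "b" stays uncertain
        positive_probs=[1.0, 0.7, 0.0],
        positive_probs_std=[0.0, 0.1, 0.0],
    )
    # beta and resid_sigma can then enter a likelihood, e.g.
    # pm.Normal("y", X @ beta, resid_sigma, observed=y)
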