-
Notifications
You must be signed in to change notification settings - Fork 83
Support for custom priors via Prior class #488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
b35001b
b7300e7
a60035e
367c922
a9f821c
91aee00
dc20e3e
f51f994
7565b7b
4312dc9
57ba733
b57810a
1a0b078
787a10e
bcba49f
0650644
3c659d3
4be4cdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
import pytensor.tensor as pt | ||
import xarray as xr | ||
from arviz import r2_score | ||
from pymc_extras.prior import Prior | ||
|
||
from causalpy.utils import round_num | ||
|
||
|
@@ -68,7 +69,13 @@ class PyMCModel(pm.Model): | |
Inference data... | ||
""" | ||
|
||
def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): | ||
default_priors: dict[str, Any] | ||
|
||
def __init__( | ||
self, | ||
sample_kwargs: Optional[Dict[str, Any]] = None, | ||
priors: dict[str, Any] | None = None, | ||
): | ||
""" | ||
:param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the | ||
:func:`pymc.sample` function. Defaults to an empty dictionary. | ||
|
@@ -77,6 +84,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): | |
self.idata = None | ||
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {} | ||
|
||
self.priors = {**self.default_priors, **(priors or {})} | ||
|
||
def build_model(self, X, y, coords) -> None: | ||
"""Build the model, must be implemented by subclass.""" | ||
raise NotImplementedError("This method must be implemented by a subclass") | ||
|
@@ -237,6 +246,11 @@ class LinearRegression(PyMCModel): | |
Inference data... | ||
""" # noqa: W605 | ||
|
||
default_priors = { | ||
"beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"), | ||
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), | ||
} | ||
|
||
def build_model(self, X, y, coords): | ||
""" | ||
Defines the PyMC model | ||
|
@@ -245,10 +259,9 @@ def build_model(self, X, y, coords): | |
self.add_coords(coords) | ||
X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) | ||
y = pm.Data("y", y, dims="obs_ind") | ||
beta = pm.Normal("beta", 0, 50, dims="coeffs") | ||
sigma = pm.HalfNormal("sigma", 1) | ||
williambdean marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
beta = self.priors["beta"].create_variable("beta") | ||
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") | ||
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") | ||
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) | ||
|
||
|
||
class WeightedSumFitter(PyMCModel): | ||
|
@@ -276,6 +289,10 @@ class WeightedSumFitter(PyMCModel): | |
Inference data... | ||
""" # noqa: W605 | ||
|
||
default_priors = { | ||
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), | ||
} | ||
|
||
def build_model(self, X, y, coords): | ||
""" | ||
Defines the PyMC model | ||
|
@@ -286,9 +303,8 @@ def build_model(self, X, y, coords): | |
X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) | ||
y = pm.Data("y", y[:, 0], dims="obs_ind") | ||
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs") | ||
|
||
sigma = pm.HalfNormal("sigma", 1) | ||
williambdean marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") | ||
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") | ||
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) | ||
|
||
|
||
class InstrumentalVariableRegression(PyMCModel): | ||
|
@@ -477,13 +493,17 @@ class PropensityScore(PyMCModel): | |
Inference... | ||
""" # noqa: W605 | ||
|
||
default_priors = { | ||
"b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), | ||
} | ||
|
||
def build_model(self, X, t, coords): | ||
"Defines the PyMC propensity model" | ||
with self: | ||
self.add_coords(coords) | ||
X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) | ||
t_data = pm.Data("t", t.flatten(), dims="obs_ind") | ||
b = pm.Normal("b", mu=0, sigma=1, dims="coeffs") | ||
b = self.priors["b"].create_variable("b") | ||
mu = pm.math.dot(X_data, b) | ||
p = pm.Deterministic("p", pm.math.invlogit(mu)) | ||
pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,4 @@ dependencies: | |
- seaborn>=0.11.2 | ||
- statsmodels | ||
- xarray>=v2022.11.0 | ||
- pymc-extras>=0.2.7 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to add
@property
decorator here? Or is that remembered from it being done in thePyMCModel
base class?Getting an Pylance warning:
Type "dict[str, Prior]" is not assignable to declared type "property"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What line of code bring that on? Maybe having a setter will help?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree, if you want a property, maybe we can have a setter method? (not a blocker for now and maybe create an issue?)