Skip to content

Commit 96fbc30

Browse files
committed
Implement uncensor and forecast_timeseries model transformation
1 parent 103025b commit 96fbc30

File tree

2 files changed

+265
-0
lines changed

2 files changed

+265
-0
lines changed

pymc_experimental/model_transform.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from pymc import DiracDelta
2+
from pymc.distributions.censored import CensoredRV
3+
from pymc.distributions.timeseries import AR, AutoRegressiveRV
4+
from pymc.model import Model
5+
from pytensor import shared
6+
from pytensor.graph import FunctionGraph, node_rewriter
7+
from pytensor.graph.basic import get_var_by_name
8+
from pytensor.graph.rewriting.basic import in2out
9+
10+
from pymc_experimental.utils.model_fgraph import (
11+
ModelObservedRV,
12+
ModelValuedVar,
13+
fgraph_from_model,
14+
model_free_rv,
15+
model_from_fgraph,
16+
model_named,
17+
)
18+
19+
__all__ = (
20+
"forecast_timeseries",
21+
"uncensor",
22+
)
23+
24+
25+
@node_rewriter(tracks=[ModelValuedVar])
26+
def uncensor_node_rewrite(fgraph, node):
27+
"""Rewrite that replaces censored variables by uncensored ones"""
28+
29+
(
30+
censored_rv,
31+
value,
32+
*dims,
33+
) = node.inputs
34+
if not isinstance(censored_rv.owner.op, CensoredRV):
35+
return
36+
37+
model_rv = node.outputs[0]
38+
base_rv = censored_rv.owner.inputs[0]
39+
uncensored_rv = node.op.make_node(base_rv, value, *dims).default_output()
40+
uncensored_rv.name = f"{model_rv.name}_uncensored"
41+
return [uncensored_rv]
42+
43+
44+
uncensor_rewrite = in2out(uncensor_node_rewrite)
45+
46+
47+
def uncensor(model: Model) -> Model:
48+
"""Replace censored variables in the model by uncensored equivalent.
49+
50+
Replaced variables have the same name as original ones with an additional "_uncensored" suffix.
51+
52+
.. code-block:: python
53+
54+
import arviz as az
55+
import pymc as pm
56+
from pymc_experimental.model_transform import uncensor
57+
58+
with pm.Model() as model:
59+
x = pm.Normal("x")
60+
dist_raw = pm.Normal.dist(x)
61+
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[-1, 0.5, 1, 1, 1])
62+
idata = pm.sample()
63+
64+
with uncensor(model):
65+
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_uncensored"])
66+
67+
az.summary(idata_pp)
68+
"""
69+
fg = fgraph_from_model(model)
70+
71+
(_, nodes_changed, *_) = uncensor_rewrite.apply(fg)
72+
if not nodes_changed:
73+
raise RuntimeError("No censored variables were replaced by uncensored counterparts")
74+
75+
return model_from_fgraph(fg)
76+
77+
78+
@node_rewriter(tracks=[ModelValuedVar])
79+
def forecast_timeseries_node_rewrite(fgraph: FunctionGraph, node):
80+
"""Rewrite that replaces timeseries variables by new ones starting at the last timepoint(s)."""
81+
82+
(
83+
timeseries_rv,
84+
value,
85+
*dims,
86+
) = node.inputs
87+
if not isinstance(timeseries_rv.owner.op, AutoRegressiveRV):
88+
return
89+
90+
forecast_steps = get_var_by_name(fgraph.inputs, "forecast_steps_")
91+
if len(forecast_steps) != 1:
92+
return False
93+
94+
forecast_steps = forecast_steps[0]
95+
96+
op = timeseries_rv.owner.op
97+
model_rv = node.outputs[0]
98+
99+
# We cannot reference the variable we are planning to replace
100+
# Or it will introduce circularities in the graph
101+
# FIXME: This special logic shouldn't be needed for ObservedRVs
102+
# but PyMC does not allow one to not resample observed.
103+
# We hack around by conditioning on the value variable directly,
104+
# even though that should not be part of the generative graph...
105+
if isinstance(node.op, ModelObservedRV):
106+
init_dist = DiracDelta.dist(value[..., -op.ar_order :])
107+
else:
108+
cloned_model_rv = model_rv.owner.clone().default_output()
109+
fgraph.add_output(cloned_model_rv, import_missing=True)
110+
init_dist = DiracDelta.dist(cloned_model_rv[..., -op.ar_order :])
111+
112+
if isinstance(timeseries_rv.owner.op, AutoRegressiveRV):
113+
rhos, sigma, *_ = timeseries_rv.owner.inputs
114+
new_timeseries_rv = AR.rv_op(
115+
rhos=rhos,
116+
sigma=sigma,
117+
init_dist=init_dist,
118+
steps=forecast_steps,
119+
ar_order=op.ar_order,
120+
constant_term=op.constant_term,
121+
)
122+
123+
new_name = f"{model_rv.name}_forecast"
124+
new_value = new_timeseries_rv.type(name=new_name)
125+
new_timeseries_rv = model_free_rv(new_timeseries_rv, new_value, transform=None)
126+
new_timeseries_rv.name = new_name
127+
128+
# Import new variables into fgraph (value and RNG)
129+
fgraph.import_var(new_timeseries_rv, import_missing=True)
130+
131+
return [new_timeseries_rv]
132+
133+
134+
forecast_timeseries_rewrite = in2out(forecast_timeseries_node_rewrite, ignore_newtrees=True)
135+
136+
137+
def forecast_timeseries(model: Model, forecast_steps: int) -> Model:
138+
"""Replace timeseries variables in the model by forecast that start at the last value.
139+
140+
Replaced variables have the same name as original ones with an additional "_forecast" suffix.
141+
142+
The function will fail if any variables with fixed static shape depend on the timeseries being replaced,
143+
and forecast_steps differs from the original timeseries steps.
144+
145+
.. code-block:: python
146+
147+
import pymc as pm
148+
from pymc_experimental.model_transform import forecast_timeseries
149+
150+
with pm.Model() as model:
151+
rho = pm.Normal("rho")
152+
sigma = pm.HalfNormal("sigma")
153+
init_dist = pm.Normal.dist()
154+
y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=np.zeros(100,))
155+
idata = pm.sample()
156+
157+
forecast_model = forecast_timeseries(mode, forecast_steps=20)
158+
with forecast_model:
159+
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_forecast"])
160+
161+
az.summary(idata_pp)
162+
"""
163+
164+
fg = fgraph_from_model(model)
165+
166+
forecast_steps_sh = shared(forecast_steps, name="forecast_steps_")
167+
forecast_steps_sh = model_named(forecast_steps_sh)
168+
fg.add_output(forecast_steps_sh, import_missing=True)
169+
170+
(_, nodes_changed, *_) = forecast_timeseries_rewrite.apply(fg)
171+
if not nodes_changed:
172+
raise RuntimeError("No timeseries were replaced by forecast counterparts")
173+
174+
res = model_from_fgraph(fg)
175+
return res
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import arviz as az
2+
import numpy as np
3+
import pymc as pm
4+
import pytest
5+
6+
from pymc_experimental.model_transform import forecast_timeseries, uncensor
7+
8+
9+
@pytest.mark.parametrize(
10+
"transform, kwargs",
11+
[
12+
(uncensor, dict()),
13+
(forecast_timeseries, dict(forecast_steps=20)),
14+
],
15+
)
16+
def test_transform_error(transform, kwargs):
17+
"""Test informative error is raised when the transform is not applicable to a model."""
18+
with pm.Model() as model:
19+
x = pm.Normal("x")
20+
y = pm.Normal("y", x, observed=[0, 5, 10])
21+
22+
with pytest.raises(RuntimeError, match="No .* were replaced by .* counterparts"):
23+
transform(model, **kwargs)
24+
25+
26+
def test_uncensor():
27+
with pm.Model() as model:
28+
x = pm.Normal("x")
29+
dist_raw = pm.Normal.dist(x)
30+
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[0, 5, 10])
31+
det = pm.Deterministic("det", y * 2)
32+
33+
idata = az.from_dict({"x": np.zeros((2, 500))})
34+
35+
with uncensor(model):
36+
pp = pm.sample_posterior_predictive(
37+
idata,
38+
var_names=["y_uncensored", "det"],
39+
random_seed=18,
40+
).posterior_predictive
41+
42+
assert np.any(np.abs(pp["y_uncensored"]) > 1)
43+
np.testing.assert_allclose(pp["y_uncensored"] * 2, pp["det"])
44+
45+
46+
@pytest.mark.parametrize("observed", (True, False))
47+
@pytest.mark.parametrize("ar_order", (1, 2))
48+
def test_forecast_timeseries_ar(observed, ar_order):
49+
data_steps = 3
50+
data = np.hstack((np.zeros(ar_order), (np.arange(data_steps) + 1) * 100.0))
51+
with pm.Model() as model:
52+
rho = pm.Normal("rho", shape=(ar_order,))
53+
sigma = pm.HalfNormal("sigma")
54+
init_dist = pm.Normal.dist(0, 1e-3)
55+
y = pm.AR(
56+
"y",
57+
init_dist=init_dist,
58+
rho=rho,
59+
sigma=sigma,
60+
observed=data if observed else None,
61+
steps=data_steps,
62+
)
63+
det = pm.Deterministic("det", y * 2)
64+
65+
draws = (2, 50)
66+
# These rhos mean that all steps will be data[-1] for ar_order > 1
67+
idata_dict = {
68+
"rho": np.full((*draws, ar_order), (0.1,) + (0,) * (ar_order - 1)),
69+
"sigma": np.full(draws, 1e-5),
70+
}
71+
if observed:
72+
idata = az.from_dict(idata_dict, observed_data={"y": data})
73+
else:
74+
idata_dict["y"] = np.full((*draws, len(data)), data)
75+
idata = az.from_dict(idata_dict)
76+
77+
forecast_steps = 5
78+
with forecast_timeseries(model, forecast_steps=forecast_steps):
79+
pp = pm.sample_posterior_predictive(
80+
idata,
81+
var_names=["y_forecast", "det"],
82+
random_seed=50,
83+
).posterior_predictive
84+
85+
expected = data[-1] / np.logspace(0, forecast_steps, forecast_steps + 1)
86+
expected = np.hstack((data[-ar_order:-1], expected))
87+
np.testing.assert_allclose(
88+
pp["y_forecast"].values, np.full((*draws, forecast_steps + ar_order), expected), rtol=0.01
89+
)
90+
np.testing.assert_allclose(pp["y_forecast"] * 2, pp["det"])

0 commit comments

Comments
 (0)