|
| 1 | +from pymc import DiracDelta |
| 2 | +from pymc.distributions.censored import CensoredRV |
| 3 | +from pymc.distributions.timeseries import AR, AutoRegressiveRV |
| 4 | +from pymc.model import Model |
| 5 | +from pytensor import shared |
| 6 | +from pytensor.graph import FunctionGraph, node_rewriter |
| 7 | +from pytensor.graph.basic import get_var_by_name |
| 8 | +from pytensor.graph.rewriting.basic import in2out |
| 9 | + |
| 10 | +from pymc_experimental.utils.model_fgraph import ( |
| 11 | + ModelObservedRV, |
| 12 | + ModelValuedVar, |
| 13 | + fgraph_from_model, |
| 14 | + model_free_rv, |
| 15 | + model_from_fgraph, |
| 16 | + model_named, |
| 17 | +) |
| 18 | + |
# Public API of this module: the two model-level transforms defined below.
__all__ = ("forecast_timeseries", "uncensor")
| 23 | + |
| 24 | + |
@node_rewriter(tracks=[ModelValuedVar])
def uncensor_node_rewrite(fgraph, node):
    """Swap a valued censored variable for an uncensored counterpart.

    The inputs of the tracked node are unpacked as ``(rv, value, *dims)``.
    When ``rv`` was not produced by a ``CensoredRV`` op the node is left
    untouched (``None`` is returned). Otherwise the node's own op is
    re-applied to the first input of the censored op, reusing the original
    value and dims, and the result is named ``"<original name>_uncensored"``.
    """
    rv, value_var, *dim_vars = node.inputs

    censoring_node = rv.owner
    if not isinstance(censoring_node.op, CensoredRV):
        return None

    original_name = node.outputs[0].name
    # The first input of the censored op is used as the uncensored replacement.
    underlying_rv = censoring_node.inputs[0]
    replacement = node.op.make_node(underlying_rv, value_var, *dim_vars).default_output()
    replacement.name = f"{original_name}_uncensored"
    return [replacement]
| 42 | + |
| 43 | + |
# Wrap the node-level rewrite with `in2out` so that it can be applied to a
# whole FunctionGraph via `uncensor_rewrite.apply(fg)` (see `uncensor`).
uncensor_rewrite = in2out(uncensor_node_rewrite)
| 45 | + |
| 46 | + |
def uncensor(model: Model) -> Model:
    """Return a model in which censored variables are replaced by uncensored ones.

    Each replacement keeps the name of the variable it replaces, with an
    extra "_uncensored" suffix appended.

    A ``RuntimeError`` is raised when the rewrite reports that nothing was
    replaced.

    .. code-block:: python

        import arviz as az
        import pymc as pm
        from pymc_experimental.model_transform import uncensor

        with pm.Model() as model:
            x = pm.Normal("x")
            dist_raw = pm.Normal.dist(x)
            y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[-1, 0.5, 1, 1, 1])
            idata = pm.sample()

        with uncensor(model):
            idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_uncensored"])

        az.summary(idata_pp)
    """
    model_fgraph = fgraph_from_model(model)

    # The second entry of the value returned by `apply` is the one checked
    # to decide whether any node was replaced.
    _, num_replaced, *_ = uncensor_rewrite.apply(model_fgraph)
    if not num_replaced:
        raise RuntimeError("No censored variables were replaced by uncensored counterparts")

    return model_from_fgraph(model_fgraph)
| 76 | + |
| 77 | + |
@node_rewriter(tracks=[ModelValuedVar])
def forecast_timeseries_node_rewrite(fgraph: FunctionGraph, node):
    """Rewrite that replaces timeseries variables by new ones starting at the last timepoint(s).

    Only nodes whose random variable was produced by an ``AutoRegressiveRV``
    op are rewritten. The number of forecast steps is read from the variable
    named ``"forecast_steps_"`` reachable from ``fgraph.inputs``; the rewrite
    is skipped unless exactly one such variable is found.

    The replacement is a new ``AR`` variable that reuses the ``rhos`` and
    ``sigma`` inputs, ``ar_order`` and ``constant_term`` of the original, and
    whose ``init_dist`` is a ``DiracDelta`` over the last ``ar_order`` entries
    of the original series. It is registered with ``model_free_rv`` under the
    name ``"<original name>_forecast"``.

    Returns a one-element list with the replacement, or ``None`` when the
    node is left unchanged.
    """
    timeseries_rv, value, *_ = node.inputs

    op = timeseries_rv.owner.op
    if not isinstance(op, AutoRegressiveRV):
        return None

    forecast_steps = get_var_by_name(fgraph.inputs, "forecast_steps_")
    if len(forecast_steps) != 1:
        return None
    forecast_steps = forecast_steps[0]

    model_rv = node.outputs[0]

    # We cannot reference the variable we are planning to replace
    # Or it will introduce circularities in the graph
    # FIXME: This special logic shouldn't be needed for ObservedRVs
    #  but PyMC does not allow one to not resample observed.
    #  We hack around by conditioning on the value variable directly,
    #  even though that should not be part of the generative graph...
    if isinstance(node.op, ModelObservedRV):
        init_dist = DiracDelta.dist(value[..., -op.ar_order :])
    else:
        # Condition on a clone of the replaced variable, kept alive as an
        # extra output of the fgraph.
        cloned_model_rv = model_rv.owner.clone().default_output()
        fgraph.add_output(cloned_model_rv, import_missing=True)
        init_dist = DiracDelta.dist(cloned_model_rv[..., -op.ar_order :])

    # The op was already checked to be an AutoRegressiveRV above, so the
    # forecast variable can be built unconditionally.
    rhos, sigma, *_ = timeseries_rv.owner.inputs
    new_timeseries_rv = AR.rv_op(
        rhos=rhos,
        sigma=sigma,
        init_dist=init_dist,
        steps=forecast_steps,
        ar_order=op.ar_order,
        constant_term=op.constant_term,
    )

    new_name = f"{model_rv.name}_forecast"
    new_value = new_timeseries_rv.type(name=new_name)
    new_timeseries_rv = model_free_rv(new_timeseries_rv, new_value, transform=None)
    new_timeseries_rv.name = new_name

    # Import new variables into fgraph (value and RNG)
    fgraph.import_var(new_timeseries_rv, import_missing=True)

    return [new_timeseries_rv]
| 132 | + |
| 133 | + |
# Wrap the node-level rewrite with `in2out` so that it can be applied to a
# whole FunctionGraph via `forecast_timeseries_rewrite.apply(fg)`.
# NOTE(review): `ignore_newtrees=True` presumably stops the walker from
# visiting the nodes the rewrite itself introduces - confirm in pytensor docs.
forecast_timeseries_rewrite = in2out(forecast_timeseries_node_rewrite, ignore_newtrees=True)
| 135 | + |
| 136 | + |
def forecast_timeseries(model: Model, forecast_steps: int) -> Model:
    """Return a model whose timeseries variables are replaced by forecasts.

    Each forecast starts at the last value(s) of the variable it replaces and
    keeps that variable's name with an extra "_forecast" suffix appended.

    The function will fail if any variables with fixed static shape depend on
    the timeseries being replaced, and forecast_steps differs from the
    original timeseries steps.

    A ``RuntimeError`` is raised when the rewrite reports that nothing was
    replaced.

    .. code-block:: python

        import arviz as az
        import numpy as np
        import pymc as pm
        from pymc_experimental.model_transform import forecast_timeseries

        with pm.Model() as model:
            rho = pm.Normal("rho")
            sigma = pm.HalfNormal("sigma")
            init_dist = pm.Normal.dist()
            y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=np.zeros(100,))
            idata = pm.sample()

        forecast_model = forecast_timeseries(model, forecast_steps=20)
        with forecast_model:
            idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_forecast"])

        az.summary(idata_pp)
    """
    model_fgraph = fgraph_from_model(model)

    # Expose the number of steps as a named shared variable in the graph so
    # the node rewrite can look it up by the name "forecast_steps_".
    steps_var = model_named(shared(forecast_steps, name="forecast_steps_"))
    model_fgraph.add_output(steps_var, import_missing=True)

    _, num_replaced, *_ = forecast_timeseries_rewrite.apply(model_fgraph)
    if not num_replaced:
        raise RuntimeError("No timeseries were replaced by forecast counterparts")

    return model_from_fgraph(model_fgraph)
0 commit comments