|
5 | 5 | from pytensor.tensor import TensorVariable
|
6 | 6 |
|
7 | 7 | from pymc_experimental.utils.model_fgraph import (
|
| 8 | + ModelDeterministic, |
8 | 9 | ModelFreeRV,
|
9 | 10 | extract_dims,
|
10 | 11 | fgraph_from_model,
|
|
16 | 17 |
|
17 | 18 |
|
18 | 19 | def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
|
19 |
| - """Convert free RVs to observed RVs. |
| 20 | + """Convert free RVs or Deterministics to observed RVs. |
20 | 21 |
|
21 | 22 | Parameters
|
22 | 23 | ----------
|
@@ -47,29 +48,48 @@ def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable
|
47 | 48 |
|
48 | 49 | m_new = observe(m, {y: 0.5})
|
49 | 50 |
|
| 51 | + Deterministic variables can also be observed. |
| 52 | + This relies on PyMC ability to infer the logp of the underlying expression |
| 53 | +
|
| 54 | + .. code-block:: python |
| 55 | +
|
| 56 | + import pymc as pm |
| 57 | + from pymc_experimental.model_transform.conditioning import observe |
| 58 | +
|
| 59 | + with pm.Model() as m: |
| 60 | + x = pm.Normal("x") |
| 61 | + y = pm.Normal.dist(x, shape=(5,)) |
| 62 | + y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1)) |
| 63 | +
|
| 64 | + new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]}) |
| 65 | +
|
| 66 | +
|
50 | 67 | """
|
51 | 68 | vars_to_observations = {
|
52 | 69 | model[var] if isinstance(var, str) else var: obs
|
53 | 70 | for var, obs in vars_to_observations.items()
|
54 | 71 | }
|
55 | 72 |
|
56 | 73 | # Note: Since PyMC can infer logprob expressions we could also allow observing Deterministics
|
57 |
| - if any(var not in model.free_RVs for var in vars_to_observations): |
58 |
| - raise ValueError(f"At least one var is not a free variable in the model") |
| 74 | + valid_model_vars = set(model.free_RVs + model.deterministics) |
| 75 | + if any(var not in valid_model_vars for var in vars_to_observations): |
| 76 | + raise ValueError(f"At least one var is not a free variable or deterministic in the model") |
59 | 77 |
|
60 | 78 | fgraph, memo = fgraph_from_model(model)
|
61 | 79 |
|
62 | 80 | replacements = {}
|
63 | 81 | for var, obs in vars_to_observations.items():
|
64 |
| - model_free_rv = memo[var] |
| 82 | + model_var = memo[var] |
65 | 83 |
|
66 | 84 | # Just a sanity check
|
67 |
| - assert isinstance(model_free_rv.owner.op, ModelFreeRV) |
68 |
| - assert model_free_rv in fgraph.variables |
| 85 | + assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic)) |
| 86 | + assert model_var in fgraph.variables |
69 | 87 |
|
70 |
| - rv, vv, *dims = model_free_rv.owner.inputs |
71 |
| - model_obs_rv = model_observed_rv(rv, rv.type.filter_variable(obs), *dims) |
72 |
| - replacements[model_free_rv] = model_obs_rv |
| 88 | + var = model_var.owner.inputs[0] |
| 89 | + var.name = model_var.name |
| 90 | + dims = extract_dims(var) |
| 91 | + model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims) |
| 92 | + replacements[model_var] = model_obs_rv |
73 | 93 |
|
74 | 94 | toposort_replace(fgraph, tuple(replacements.items()))
|
75 | 95 |
|
|
0 commit comments