Skip to content

Commit 1a0eedf

Browse files
committed
Expand observe transformation to Deterministics
1 parent cec2963 commit 1a0eedf

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

pymc_experimental/model_transform/conditioning.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.tensor import TensorVariable
66

77
from pymc_experimental.utils.model_fgraph import (
8+
ModelDeterministic,
89
ModelFreeRV,
910
extract_dims,
1011
fgraph_from_model,
@@ -16,7 +17,7 @@
1617

1718

1819
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
19-
"""Convert free RVs to observed RVs.
20+
"""Convert free RVs or Deterministics to observed RVs.
2021
2122
Parameters
2223
----------
@@ -47,29 +48,48 @@ def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable
4748
4849
m_new = observe(m, {y: 0.5})
4950
51+
Deterministic variables can also be observed.
52+
This relies on PyMC ability to infer the logp of the underlying expression
53+
54+
.. code-block:: python
55+
56+
import pymc as pm
57+
from pymc_experimental.model_transform.conditioning import observe
58+
59+
with pm.Model() as m:
60+
x = pm.Normal("x")
61+
y = pm.Normal.dist(x, shape=(5,))
62+
y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
63+
64+
new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})
65+
66+
5067
"""
5168
vars_to_observations = {
5269
model[var] if isinstance(var, str) else var: obs
5370
for var, obs in vars_to_observations.items()
5471
}
5572

5673
# Note: Since PyMC can infer logprob expressions we could also allow observing Deterministics
57-
if any(var not in model.free_RVs for var in vars_to_observations):
58-
raise ValueError(f"At least one var is not a free variable in the model")
74+
valid_model_vars = set(model.free_RVs + model.deterministics)
75+
if any(var not in valid_model_vars for var in vars_to_observations):
76+
raise ValueError(f"At least one var is not a free variable or deterministic in the model")
5977

6078
fgraph, memo = fgraph_from_model(model)
6179

6280
replacements = {}
6381
for var, obs in vars_to_observations.items():
64-
model_free_rv = memo[var]
82+
model_var = memo[var]
6583

6684
# Just a sanity check
67-
assert isinstance(model_free_rv.owner.op, ModelFreeRV)
68-
assert model_free_rv in fgraph.variables
85+
assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
86+
assert model_var in fgraph.variables
6987

70-
rv, vv, *dims = model_free_rv.owner.inputs
71-
model_obs_rv = model_observed_rv(rv, rv.type.filter_variable(obs), *dims)
72-
replacements[model_free_rv] = model_obs_rv
88+
var = model_var.owner.inputs[0]
89+
var.name = model_var.name
90+
dims = extract_dims(var)
91+
model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims)
92+
replacements[model_var] = model_obs_rv
7393

7494
toposort_replace(fgraph, tuple(replacements.items()))
7595

pymc_experimental/tests/model_transform/test_conditioning.py

+20
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,26 @@ def test_observe_minibatch():
6666
)
6767

6868

69+
def test_observe_deterministic():
70+
y_censored_obs = [0.9, 0.5, 0.3, 1, 1]
71+
72+
with pm.Model() as m_old:
73+
x = pm.Normal("x")
74+
y = pm.Normal.dist(x, shape=(5,))
75+
y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
76+
77+
m_new = observe(m_old, {y_censored: y_censored_obs})
78+
79+
with pm.Model() as m_ref:
80+
x = pm.Normal("x")
81+
pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)
82+
83+
np.testing.assert_allclose(
84+
m_new.compile_logp()({"x": 0.9}),
85+
m_ref.compile_logp()({"x": 0.9}),
86+
)
87+
88+
6989
def test_do():
7090
with pm.Model() as m_old:
7191
x = pm.Normal("x", 0, 1e-3)

0 commit comments

Comments
 (0)