Skip to content

Commit 14d37ec

Browse files
committed
Add option to prune variables after do intervention
1 parent f662d82 commit 14d37ec

File tree

4 files changed

+92
-2
lines changed

4 files changed

+92
-2
lines changed
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pymc import Model
2+
from pytensor.graph import ancestors
3+
4+
from pymc_experimental.utils.model_fgraph import (
5+
ModelObservedRV,
6+
ModelVar,
7+
fgraph_from_model,
8+
model_from_fgraph,
9+
)
10+
11+
12+
def prune_vars_detached_from_observed(model: Model) -> Model:
    """Drop model variables that no observed variable depends on.

    The model is converted to a function graph, every ``ModelVar`` node that
    is not the owner of a variable returned by ``ancestors`` of the observed
    outputs is removed, and a model is rebuilt from the remaining graph.

    Parameters
    ----------
    model : Model
        Model to prune. Models that define Potentials are rejected.

    Returns
    -------
    Model
        The model rebuilt from the pruned function graph.

    Raises
    ------
    NotImplementedError
        If ``model`` has any Potentials.
    """
    # A Potential may stand for either a prior or a likelihood term, so it is
    # unclear whether it should count as "observed"; refuse instead of guessing.
    if model.potentials:
        raise NotImplementedError("Pruning not implemented for models with Potentials")

    fgraph, _ = fgraph_from_model(model, inlined_views=True)

    # Outputs of every observed-variable node in the graph.
    observed_outputs = [
        output
        for apply_node in fgraph.apply_nodes
        if isinstance(apply_node.op, ModelObservedRV)
        for output in apply_node.outputs
    ]

    # Nodes that produce something the observed outputs depend on.
    # Root variables contribute ``None`` here, which never matches a node.
    kept_nodes = set()
    for variable in ancestors(observed_outputs):
        kept_nodes.add(variable.owner)

    # Gather the detached model-variable nodes first so the graph is not
    # mutated while its node collection is being iterated.
    detached_nodes = set()
    for apply_node in fgraph.apply_nodes:
        if not isinstance(apply_node.op, ModelVar):
            continue
        if apply_node in kept_nodes:
            continue
        detached_nodes.add(apply_node)

    for detached_node in detached_nodes:
        fgraph.remove_node(detached_node)

    return model_from_fgraph(fgraph)

pymc_experimental/model_transform/conditioning.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pymc.pytensorf import _replace_vars_in_graphs
55
from pytensor.tensor import TensorVariable
66

7+
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
78
from pymc_experimental.utils.model_fgraph import (
89
ModelDeterministic,
910
ModelFreeRV,
@@ -113,7 +114,9 @@ def replacement_fn(var, inner_replacements):
113114
return replaced_graphs
114115

115116

116-
def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
117+
def do(
118+
model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False
119+
) -> Model:
117120
"""Replace model variables by intervention variables.
118121
119122
Intervention variables will either show up as `Data` or `Deterministics` in the new model,
@@ -126,6 +129,9 @@ def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], A
126129
Dictionary that maps model variables (or names) to intervention expressions.
127130
Intervention expressions must have a shape and data type that is compatible
128131
with the original model variable.
132+
prune_vars: bool, defaults to False
133+
Whether to prune model variables that are not connected to any observed variables,
134+
after the interventions.
129135
130136
Returns
131137
-------
@@ -196,4 +202,7 @@ def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], A
196202
# Replace variables by interventions
197203
toposort_replace(fgraph, tuple(replacements.items()))
198204

199-
return model_from_fgraph(fgraph)
205+
model = model_from_fgraph(fgraph)
206+
if prune_vars:
207+
return prune_vars_detached_from_observed(model)
208+
return model
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pymc as pm
2+
3+
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
4+
5+
6+
def test_prune_vars_detached_from_observed():
    """Pruning keeps the chain feeding ``obs`` and drops the unrelated one."""
    with pm.Model() as m:
        # Chain that ends in an observed variable: a0 -> a1 -> a2 -> obs.
        obs_data = pm.MutableData("obs_data", 0)
        a0 = pm.ConstantData("a0", 0)
        a1 = pm.Normal("a1", a0)
        a2 = pm.Normal("a2", a1)
        pm.Normal("obs", a2, observed=obs_data)

        # Chain that feeds no observed variable: d0 -> d1.
        d0 = pm.ConstantData("d0", 0)
        d1 = pm.Normal("d1", d0)

    connected_names = {"obs_data", "a0", "a1", "a2", "obs"}
    detached_names = {"d0", "d1"}

    # Before pruning, both chains are registered in the model.
    assert set(m.named_vars.keys()) == connected_names | detached_names

    pruned_m = prune_vars_detached_from_observed(m)

    # After pruning, only the chain connected to the observation remains.
    assert set(pruned_m.named_vars.keys()) == connected_names

pymc_experimental/tests/model_transform/test_conditioning.py

+27
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,30 @@ def test_do_dims():
192192
},
193193
)
194194
assert do_m.named_vars_to_dims["y"] == ["test_dim"]
195+
196+
197+
@pytest.mark.parametrize("prune", (False, True))
def test_do_prune(prune):
    """Check which variables survive ``do`` with and without ``prune_vars``."""
    with pm.Model() as m:
        x0 = pm.ConstantData("x0", 0)
        x1 = pm.ConstantData("x1", 0)
        y = pm.Normal("y")
        y_det = pm.Deterministic("y_det", y + x0)
        z = pm.Normal("z", y_det)
        llike = pm.Normal("llike", z + x1, observed=0)

    all_names = {"x0", "x1", "y", "y_det", "z", "llike"}
    assert set(m.named_vars) == all_names

    # Intervening on y_det: with pruning only `y` is expected to disappear.
    expected_after_y_det = {"x0", "x1", "y_det", "z", "llike"} if prune else all_names
    intervened_m = do(m, {y_det: x0 + 5}, prune_vars=prune)
    assert set(intervened_m.named_vars) == expected_after_y_det

    # Intervening on z with a constant: with pruning everything feeding z
    # (`x0`, `y`, `y_det`) is expected to disappear.
    expected_after_z = {"x1", "z", "llike"} if prune else all_names
    intervened_m = do(m, {z: 0.5}, prune_vars=prune)
    assert set(intervened_m.named_vars) == expected_after_z

0 commit comments

Comments
 (0)