Skip to content

Commit 08a60ae

Browse files
committed
Return memo dictionary in fgraph_from_model
1 parent 3bb16fc commit 08a60ae

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

pymc_experimental/tests/utils/test_model_fgraph.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from pytensor.tensor.exceptions import NotScalarConstantError
88

99
from pymc_experimental.utils.model_fgraph import (
10+
ModelDeterministic,
1011
ModelFreeRV,
12+
ModelNamed,
13+
ModelObservedRV,
14+
ModelPotential,
1115
ModelVar,
1216
fgraph_from_model,
1317
model_deterministic,
@@ -23,11 +27,17 @@ def test_basic():
2327
y = pm.Deterministic("y", x + 1)
2428
w = pm.HalfNormal("w", pm.math.exp(y))
2529
z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
26-
pm.Potential("pot", x * 2)
30+
pot = pm.Potential("pot", x * 2)
2731

28-
m_fgraph = fgraph_from_model(m_old)
32+
m_fgraph, memo = fgraph_from_model(m_old)
2933
assert isinstance(m_fgraph, FunctionGraph)
3034

35+
assert isinstance(memo[x].owner.op, ModelFreeRV)
36+
assert isinstance(memo[y].owner.op, ModelDeterministic)
37+
assert isinstance(memo[w].owner.op, ModelFreeRV)
38+
assert isinstance(memo[z].owner.op, ModelObservedRV)
39+
assert isinstance(memo[pot].owner.op, ModelPotential)
40+
3141
m_new = model_from_fgraph(m_fgraph)
3242
assert isinstance(m_new, pm.Model)
3343

@@ -79,7 +89,12 @@ def test_data():
7989
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
8090
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
8191

82-
m_new = model_from_fgraph(fgraph_from_model(m_old))
92+
m_fgraph, memo = fgraph_from_model(m_old)
93+
assert isinstance(memo[x].owner.op, ModelNamed)
94+
assert isinstance(memo[y].owner.op, ModelNamed)
95+
assert isinstance(memo[b0].owner.op, ModelNamed)
96+
97+
m_new = model_from_fgraph(m_fgraph)
8398

8499
# ConstantData is preserved
85100
assert m_new["b0"].data == m_old["b0"].data
@@ -125,7 +140,7 @@ def test_deterministics():
125140
assert m["y"].owner.inputs[3] is m["mu"]
126141
assert m["y"].owner.inputs[4] is not m["sigma"]
127142

128-
fg = fgraph_from_model(m)
143+
fg, _ = fgraph_from_model(m)
129144

130145
# Check that no Deterministics are in graph of x to y and y to z
131146
x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
@@ -173,7 +188,7 @@ def test_sub_model_error():
173188
with pm.Model() as sub_m:
174189
y = pm.Normal("y", x)
175190

176-
nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)]
191+
nodes = [v for v in fgraph_from_model(m)[0].toposort() if not isinstance(v.op, ModelVar)]
177192
assert len(nodes) == 2
178193
assert isinstance(nodes[0].op, pm.Beta)
179194
assert isinstance(nodes[1].op, pm.Normal)
@@ -234,7 +249,7 @@ def test_fgraph_rewrite(non_centered_rewrite):
234249
subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
235250
obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))
236251

237-
fg = fgraph_from_model(m_old)
252+
fg, _ = fgraph_from_model(m_old)
238253
non_centered_rewrite.apply(fg)
239254

240255
m_new = model_from_fgraph(fg)

pymc_experimental/utils/model_fgraph.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence, Tuple
1+
from typing import Dict, Optional, Sequence, Tuple
22

33
import pytensor
44
from pymc.logprob.transforms import RVTransform
@@ -109,16 +109,19 @@ def local_remove_identity(fgraph, node):
109109
remove_identity_rewrite = out2in(local_remove_identity)
110110

111111

112-
def fgraph_from_model(model: Model) -> FunctionGraph:
112+
def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
113113
"""Convert Model to FunctionGraph.
114114

115-
Create a FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
116-
117-
PyTensor rewrites can be used to transform the FunctionGraph.
115+
See: model_from_fgraph
118116

119-
It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
117+
Returns
118+
-------
119+
fgraph: FunctionGraph
120+
FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
121+
It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
120122

121-
See: model_from_fgraph
123+
memo: Dict
124+
A dictionary mapping original model variables to the equivalent nodes in the fgraph.
122125
"""
123126

124127
if any(v is not None for v in model.rvs_to_initial_values.values()):
@@ -202,7 +205,19 @@ def fgraph_from_model(model: Model) -> FunctionGraph:
202205
new_var = var
203206
new_vars.append(new_var)
204207

205-
toposort_replace(fgraph, tuple(zip(vars, new_vars)))
208+
replacements = tuple(zip(vars, new_vars))
209+
toposort_replace(fgraph, replacements)
210+
211+
# Reference model vars in memo
212+
inverse_memo = {v: k for k, v in memo.items()}
213+
for var, model_var in replacements:
214+
if isinstance(
215+
model_var.owner is not None and model_var.owner.op, (ModelDeterministic, ModelNamed)
216+
):
217+
# Ignore extra identity that will be removed at the end
218+
var = var.owner.inputs[0]
219+
original_var = inverse_memo[var]
220+
memo[original_var] = model_var
206221

207222
# Remove value variable as outputs, now that they are graph inputs
208223
first_value_idx = len(fgraph.outputs) - len(value_vars)
@@ -212,7 +227,7 @@ def fgraph_from_model(model: Model) -> FunctionGraph:
212227
# Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
213228
remove_identity_rewrite.apply(fgraph)
214229

215-
return fgraph
230+
return fgraph, memo
216231

217232

218233
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
@@ -282,7 +297,7 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
282297
return model
283298

284299

285-
def clone_model(model: Model) -> Model:
300+
def clone_model(model: Model) -> Tuple[Model]:
286301
"""Clone a PyMC model.
287302

288303
Recreates a PyMC model with clones of the original variables.
@@ -310,4 +325,4 @@ def clone_model(model: Model) -> Model:
310325
z = pm.Deterministic("z", clone_x + 1)
311326

312327
"""
313-
return model_from_fgraph(fgraph_from_model(model))
328+
return model_from_fgraph(fgraph_from_model(model)[0])

0 commit comments

Comments
 (0)