7
7
from pytensor.tensor.exceptions import NotScalarConstantError
8
8
9
9
from pymc_experimental.utils.model_fgraph import (
10
+ ModelDeterministic,
10
11
ModelFreeRV,
12
+ ModelNamed,
13
+ ModelObservedRV,
14
+ ModelPotential,
11
15
ModelVar,
12
16
fgraph_from_model,
13
17
model_deterministic,
@@ -23,11 +27,17 @@ def test_basic():
23
27
y = pm.Deterministic("y", x + 1)
24
28
w = pm.HalfNormal("w", pm.math.exp(y))
25
29
z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
26
- pm.Potential("pot", x * 2)
30
+ pot = pm.Potential("pot", x * 2)
27
31
28
- m_fgraph = fgraph_from_model(m_old)
32
+ m_fgraph, memo = fgraph_from_model(m_old)
29
33
assert isinstance(m_fgraph, FunctionGraph)
30
34
35
+ assert isinstance(memo[x].owner.op, ModelFreeRV)
36
+ assert isinstance(memo[y].owner.op, ModelDeterministic)
37
+ assert isinstance(memo[w].owner.op, ModelFreeRV)
38
+ assert isinstance(memo[z].owner.op, ModelObservedRV)
39
+ assert isinstance(memo[pot].owner.op, ModelPotential)
40
+
31
41
m_new = model_from_fgraph(m_fgraph)
32
42
assert isinstance(m_new, pm.Model)
33
43
@@ -79,7 +89,12 @@ def test_data():
79
89
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
80
90
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
81
91
82
- m_new = model_from_fgraph(fgraph_from_model(m_old))
92
+ m_fgraph, memo = fgraph_from_model(m_old)
93
+ assert isinstance(memo[x].owner.op, ModelNamed)
94
+ assert isinstance(memo[y].owner.op, ModelNamed)
95
+ assert isinstance(memo[b0].owner.op, ModelNamed)
96
+
97
+ m_new = model_from_fgraph(m_fgraph)
83
98
84
99
# ConstantData is preserved
85
100
assert m_new["b0"].data == m_old["b0"].data
@@ -125,7 +140,7 @@ def test_deterministics():
125
140
assert m["y"].owner.inputs[3] is m["mu"]
126
141
assert m["y"].owner.inputs[4] is not m["sigma"]
127
142
128
- fg = fgraph_from_model(m)
143
+ fg, _ = fgraph_from_model(m)
129
144
130
145
# Check that no Deterministics are in graph of x to y and y to z
131
146
x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
@@ -173,7 +188,7 @@ def test_sub_model_error():
173
188
with pm.Model() as sub_m:
174
189
y = pm.Normal("y", x)
175
190
176
- nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)]
191
+ nodes = [v for v in fgraph_from_model(m)[0] .toposort() if not isinstance(v.op, ModelVar)]
177
192
assert len(nodes) == 2
178
193
assert isinstance(nodes[0].op, pm.Beta)
179
194
assert isinstance(nodes[1].op, pm.Normal)
@@ -234,7 +249,7 @@ def test_fgraph_rewrite(non_centered_rewrite):
234
249
subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
235
250
obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))
236
251
237
- fg = fgraph_from_model(m_old)
252
+ fg, _ = fgraph_from_model(m_old)
238
253
non_centered_rewrite.apply(fg)
239
254
240
255
m_new = model_from_fgraph(fg)
0 commit comments