
Commit 103025b

Implement utility to convert Model to and from FunctionGraph
1 parent b730449

File tree

6 files changed: +626, -1 lines

docs/api_reference.rst (+3)

@@ -49,5 +49,8 @@ Utils
 .. autosummary::
    :toctree: generated/
 
+   clone_model
    spline.bspline_interpolation
    prior.prior_from_idata
+   model_fgraph.fgraph_from_model
+   model_fgraph.model_from_fgraph
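
For orientation, the new public helpers round-trip a model through a PyTensor FunctionGraph. A minimal sketch, assembled from the tests added in this commit (the toy model below is illustrative, not part of the change):

import pymc as pm
from pymc_experimental.utils.model_fgraph import fgraph_from_model, model_from_fgraph

with pm.Model(coords={"obs_dim": range(3)}) as model:
    mu = pm.Normal("mu")
    pm.Normal("y", mu, 1.0, observed=[0.0, 1.0, 2.0], dims=("obs_dim",))

# Convert the model to a FunctionGraph, optionally rewrite it,
# then rebuild an equivalent model outside any model context.
fgraph = fgraph_from_model(model)
model_clone = model_from_fgraph(fgraph)

Per the tests below, the rebuilt model preserves coords, dims, transforms, random draws, and logp, while its named variables are fresh clones of the originals.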

pymc_experimental/tests/utils/__init__.py

Whitespace-only changes.
@@ -0,0 +1,273 @@
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
from pytensor.graph import Constant, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.exceptions import NotScalarConstantError

from pymc_experimental.utils.model_fgraph import (
    ModelFreeRV,
    ModelVar,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
)


def test_basic():
    """Test we can convert from a PyMC Model to a FunctionGraph and back"""
    with pm.Model(coords={"test_dim": range(3)}) as m_old:
        x = pm.Normal("x")
        y = pm.Deterministic("y", x + 1)
        w = pm.HalfNormal("w", pm.math.exp(y))
        z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
        pm.Potential("pot", x * 2)

    m_fgraph = fgraph_from_model(m_old)
    assert isinstance(m_fgraph, FunctionGraph)

    m_new = model_from_fgraph(m_fgraph)
    assert isinstance(m_new, pm.Model)

    assert m_new.coords == {"test_dim": tuple(range(3))}
    assert m_new._dim_lengths["test_dim"].eval() == 3
    assert m_new.named_vars_to_dims == {"z": ["test_dim"]}

    named_vars = {"x", "y", "w", "z", "pot"}
    assert set(m_new.named_vars) == named_vars
    for named_var in named_vars:
        assert m_new[named_var] is not m_old[named_var]
    for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()):
        # Constants are not cloned
        if not isinstance(value_new, Constant):
            assert value_new is not value_old
    assert m_new["x"] in m_new.free_RVs
    assert m_new["w"] in m_new.free_RVs
    assert m_new["y"] in m_new.deterministics
    assert m_new["z"] in m_new.observed_RVs
    assert m_new["pot"] in m_new.potentials
    assert m_new.rvs_to_transforms[m_new["x"]] is None
    assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log
    assert m_new.rvs_to_transforms[m_new["z"]] is None

    # Test random
    new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1)
    old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
    np.testing.assert_array_equal(new_y_draw, old_y_draw)
    np.testing.assert_array_equal(new_z_draw, old_z_draw)

    # Test logp
    ip = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(ip),
        m_old.compile_logp()(ip),
    )


def test_data():
    """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.

    Everything should be preserved across new and old models, except for shared RNGs
    """
    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
        y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
        b0 = pm.ConstantData("b0", 0.0)
        b1 = pm.Normal("b1")
        mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))

    m_new = model_from_fgraph(fgraph_from_model(m_old))

    # ConstantData is preserved
    assert m_new["b0"].data == m_old["b0"].data

    # Shared non-rng shared variables are preserved
    assert m_new["x"].container is x.container
    assert m_new["y"].container is y.container
    assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]

    # Shared rng shared variables are not preserved
    assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container

    with m_old:
        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})

    assert m_new.dim_lengths["test_dim"].eval() == 2
    np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])


def test_deterministics():
    """Test handling of deterministics.

    We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
    However we want them in the middle of Model.basic_RVs, so they display nicely in graphviz

    There is one edge case that has to be considered, when a Deterministic is just a copy of a RV.
    In that case we don't bother to reintroduce it in between other Model.basic_RVs
    """
    with pm.Model() as m:
        x = pm.Normal("x")
        mu = pm.Deterministic("mu", pm.math.abs(x))
        sigma = pm.math.exp(x)
        pm.Deterministic("sigma", sigma)
        y = pm.Normal("y", mu, sigma)
        # Special case where the Deterministic
        # is a direct view on another model variable
        y_ = pm.Deterministic("y_", y)
        # Just for kicks, make it a double one!
        y__ = pm.Deterministic("y__", y_)
        z = pm.Normal("z", y__)

    # Deterministic mu is in the graph of x to y but not sigma
    assert m["y"].owner.inputs[3] is m["mu"]
    assert m["y"].owner.inputs[4] is not m["sigma"]

    fg = fgraph_from_model(m)

    # Check that no Deterministics are in graph of x to y and y to z
    x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
    # [Det(mu), Det(sigma)]
    mu = det_mu.owner.inputs[0]
    sigma = det_sigma.owner.inputs[0]
    # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
    assert y.owner.inputs[0].owner.inputs[3] is mu
    assert y.owner.inputs[0].owner.inputs[4] is sigma
    # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
    assert z.owner.inputs[0].owner.inputs[3] is y
    # [Det(y), Det(y)], not [Det(y), Det(Det(y))]
    assert det_y_.owner.inputs[0] is y
    assert det_y__.owner.inputs[0] is y
    assert det_y_ is not det_y__

    # Both mu and sigma deterministics are now in the graph of x to y
    m = model_from_fgraph(fg)
    assert m["y"].owner.inputs[3] is m["mu"]
    assert m["y"].owner.inputs[4] is m["sigma"]
    # But not y_* in y to z, since there was no real Op in between
    assert m["z"].owner.inputs[3] is m["y"]
    assert m["y_"].owner.inputs[0] is m["y"]
    assert m["y__"].owner.inputs[0] is m["y"]


def test_context_error():
    """Test that model_from_fgraph fails when called inside a Model context.

    We can't allow it, because the new Model that's returned would be a child of whatever Model context is active.
    """
    with pm.Model() as m:
        x = pm.Normal("x")

        fg = fgraph_from_model(m)

        with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
            model_from_fgraph(fg)


def test_sub_model_error():
    """Test Error is raised when trying to convert a sub-model to fgraph."""
    with pm.Model() as m:
        x = pm.Beta("x", 1, 1)
        with pm.Model() as sub_m:
            y = pm.Normal("y", x)

    nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)]
    assert len(nodes) == 2
    assert isinstance(nodes[0].op, pm.Beta)
    assert isinstance(nodes[1].op, pm.Normal)

    with pytest.raises(ValueError, match="Nested sub-models cannot be converted"):
        fgraph_from_model(sub_m)


@pytest.fixture()
def non_centered_rewrite():
    @node_rewriter(tracks=[ModelFreeRV])
    def non_centered_param(fgraph: FunctionGraph, node):
        """Rewrite that replaces centered normal by non-centered parametrization."""

        rv, value, *dims = node.inputs
        if not isinstance(rv.owner.op, pm.Normal):
            return
        rng, size, dtype, loc, scale = rv.owner.inputs

        # Only apply rewrite if size information is explicit
        if size.ndim == 0:
            return None

        try:
            is_unit = (
                pt.get_underlying_scalar_constant_value(loc) == 0
                and pt.get_underlying_scalar_constant_value(scale) == 1
            )
        except NotScalarConstantError:
            is_unit = False

        # Nothing to do here
        if is_unit:
            return

        raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng)
        raw_norm.name = f"{rv.name}_raw_"
        raw_norm_value = raw_norm.clone()
        fgraph.add_input(raw_norm_value)
        raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims)

        new_norm = loc + raw_norm * scale
        new_norm.name = rv.name
        new_norm_det = model_deterministic(new_norm, *dims)
        fgraph.add_output(new_norm_det)

        return [new_norm]

    return in2out(non_centered_param)


def test_fgraph_rewrite(non_centered_rewrite):
    """Test we can apply a simple rewrite to a PyMC Model."""

    with pm.Model(coords={"subject": range(10)}) as m_old:
        group_mean = pm.Normal("group_mean")
        group_std = pm.HalfNormal("group_std")
        subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))

    fg = fgraph_from_model(m_old)
    non_centered_rewrite.apply(fg)

    m_new = model_from_fgraph(fg)
    assert m_new.named_vars_to_dims == {
        "subject_mean": ["subject"],
        "subject_mean_raw_": ["subject"],
        "obs": ["subject"],
    }
    assert set(m_new.named_vars) == {
        "group_mean",
        "group_std",
        "subject_mean_raw_",
        "subject_mean",
        "obs",
    }
    assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"}
    assert {rv.name for rv in m_new.observed_RVs} == {"obs"}
    assert {rv.name for rv in m_new.deterministics} == {"subject_mean"}

    with pm.Model() as m_ref:
        group_mean = pm.Normal("group_mean")
        group_std = pm.HalfNormal("group_std")
        subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,))
        subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std)
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10))

    np.testing.assert_array_equal(
        pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1),
        pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1),
    )

    ip = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(ip),
        m_ref.compile_logp()(ip),
    )

pymc_experimental/utils/__init__.py (+7, -1)

@@ -15,5 +15,11 @@
 
 from pymc_experimental.utils import prior, spline
 from pymc_experimental.utils.linear_cg import linear_cg
+from pymc_experimental.utils.model_fgraph import clone_model
 
-# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky
+__all__ = (
+    "clone_model",
+    "linear_cg",
+    "prior",
+    "spline",
+)
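
clone_model is also re-exported at the package level here. Judging by the export and the round-trip tests above, it is presumably a convenience wrapper that clones a model via the fgraph round-trip; a hedged sketch (the exact signature is not shown in this diff):

import pymc as pm
from pymc_experimental.utils import clone_model  # re-exported by this change

with pm.Model() as model:
    x = pm.Normal("x")
    pm.Normal("y", x, 1.0, observed=[0.0, 1.0])

# Assumption: clone_model(model) returns an independent copy whose named
# variables are new objects, as model_from_fgraph(fgraph_from_model(model)) does.
model_copy = clone_model(model)
assert model_copy["x"] is not model["x"]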
