Skip to content

Commit ba7968b

Browse files
committed
Implement utility to convert PyMC model to and from FunctionGraph
1 parent 6f67dec commit ba7968b

File tree

4 files changed

+345
-0
lines changed

4 files changed

+345
-0
lines changed

pymc_experimental/tests/utils/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import numpy as np
2+
import pymc as pm
3+
import pytensor.tensor as pt
4+
import pytest
5+
from pytensor.graph import FunctionGraph, node_rewriter
6+
from pytensor.graph.rewriting.basic import in2out
7+
from pytensor.tensor.exceptions import NotScalarConstantError
8+
9+
from pymc_experimental.utils.model_fgraph import (
10+
FreeRV,
11+
deterministic,
12+
fgraph_from_model,
13+
free_rv,
14+
model_from_fgraph,
15+
)
16+
17+
18+
def test_model_fgraph_conversion():
    """Round-trip a model through ``fgraph_from_model``/``model_from_fgraph``.

    Checks that coords, dims, named variables, variable groupings, draws
    and logp are all preserved by the conversion.
    """
    with pm.Model(coords={"test_dim": range(3)}) as m_old:
        x = pm.Normal("x")
        y = pm.Deterministic("y", x + 1)
        w = pm.Normal("w", y)
        z = pm.Normal("z", y, observed=[0, 1, 2], dims=("test_dim",))
        pm.Potential("pot", x * 2)

    fg = fgraph_from_model(m_old)
    assert isinstance(fg, FunctionGraph)

    m_new = model_from_fgraph(fg)
    assert isinstance(m_new, pm.Model)

    # Meta-information survives the round-trip.
    assert m_new.coords == {"test_dim": tuple(range(3))}
    assert m_new.named_vars_to_dims == {"z": ["test_dim"]}

    # All named variables are present, but as fresh (cloned) objects.
    expected_groups = {
        "x": m_new.free_RVs,
        "w": m_new.free_RVs,
        "y": m_new.deterministics,
        "z": m_new.observed_RVs,
        "pot": m_new.potentials,
    }
    assert set(m_new.named_vars) == set(expected_groups)
    for name, group in expected_groups.items():
        assert m_new[name] is not m_old[name]
        assert m_new[name] in group

    # Draws from the new model match the old one exactly.
    new_draws = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1)
    old_draws = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
    for new_draw, old_draw in zip(new_draws, old_draws):
        np.testing.assert_array_equal(new_draw, old_draw)

    # Logp evaluated at the same point matches as well.
    point = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(point),
        m_old.compile_logp()(point),
    )
55+
56+
57+
@pytest.fixture()
def non_centered_rewrite():
    """Fixture returning an in2out rewrite that non-centers Normal free RVs."""

    @node_rewriter(tracks=[FreeRV])
    def non_centered_param(fgraph: FunctionGraph, node):
        """Rewrite that replaces centered normal by non-centered parametrization."""

        # FreeRV nodes carry (rv, value, *dims); the value is not used here.
        rv, _, *dims = node.inputs
        if not isinstance(rv.owner.op, pm.Normal):
            return
        rng, size, dtype, loc, scale = rv.owner.inputs

        # Only apply rewrite if size information is explicit
        if size.ndim == 0:
            return None

        # A standard normal (loc=0, scale=1) is already "non-centered".
        # Non-constant loc/scale cannot be ruled out, so treat them as non-unit.
        try:
            is_unit = (
                pt.get_scalar_constant_value(loc) == 0 and pt.get_scalar_constant_value(scale) == 1
            )
        except NotScalarConstantError:
            is_unit = False

        # Nothing to do here
        if is_unit:
            return

        # Fresh standard-normal "raw" RV, reusing the original rng and size,
        # registered as a new free RV (with a matching value input).
        raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng)
        raw_norm.name = f"{rv.name}_raw_"
        raw_norm_value = raw_norm.clone()
        fgraph.add_input(raw_norm_value)
        raw_norm = free_rv(raw_norm, raw_norm_value, dims=dims)

        # The original RV becomes a deterministic: loc + raw * scale.
        new_norm = loc + raw_norm * scale
        new_norm.name = rv.name
        new_norm = deterministic(new_norm, dims=dims)

        return [new_norm]

    return in2out(non_centered_param)
96+
97+
98+
def test_fgraph_rewrite(non_centered_rewrite):
    """Apply the non-centered rewrite on the fgraph representation of a model
    and check the resulting model against a manually non-centered reference."""

    with pm.Model(coords={"subject": range(10)}) as m_old:
        group_mean = pm.Normal("group_mean")
        # FIXME: Transforms are not yet maintained across conversion
        group_std = pm.HalfNormal("group_std", transform=None)
        subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))

    fg = fgraph_from_model(m_old)
    non_centered_rewrite.apply(fg)

    m_new = model_from_fgraph(fg)
    # The rewrite introduces a "raw" RV that inherits subject_mean's dims.
    assert m_new.named_vars_to_dims == {
        "subject_mean": ["subject"],
        "subject_mean_raw_": ["subject"],
        "obs": ["subject"],
    }
    assert set(m_new.named_vars) == {
        "group_mean",
        "group_std",
        "subject_mean_raw_",
        "subject_mean",
        "obs",
    }
    assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"}
    assert {rv.name for rv in m_new.observed_RVs} == {"obs"}
    # subject_mean is now a deterministic function of the raw RV.
    assert {rv.name for rv in m_new.deterministics} == {"subject_mean"}

    # Reference model with the non-centered parametrization written by hand.
    with pm.Model() as m_ref:
        group_mean = pm.Normal("group_mean")
        # FIXME: Transforms are not yet maintained across conversion
        group_std = pm.HalfNormal("group_std", transform=None)
        subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,))
        subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std)
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10))

    # Draws and logp of the rewritten model must match the reference model.
    np.testing.assert_array_equal(
        pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1),
        pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1),
    )

    ip = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(ip),
        m_ref.compile_logp()(ip),
    )
+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from typing import Optional, Sequence
2+
3+
import pytensor
4+
from pymc.model import Model
5+
from pytensor.graph import Apply, FunctionGraph, Op
6+
from pytensor.tensor import TensorVariable
7+
8+
from pymc_experimental.utils.pytensorf import StringType
9+
10+
11+
class ModelVar(Op):
    """An Op that binds together a RV and its value.

    Subclasses are used purely as graph annotations that label PyMC model
    variables inside a FunctionGraph; they are never meant to be evaluated.
    """

    def make_node(self, rv, value=None, dims: Optional[Sequence[str]] = None):
        """Create a node wrapping ``rv`` (and optionally its ``value`` and ``dims``).

        ``dims`` entries are lifted to symbolic string constants, one per
        dimension of ``rv``.  The output has the same type as ``rv``.
        """
        assert isinstance(rv, TensorVariable)

        if dims is not None:
            dims = [pytensor.as_symbolic(dim) for dim in dims]
            assert all(isinstance(dim.type, StringType) for dim in dims)
            assert len(dims) == rv.type.ndim
        else:
            dims = ()

        if value is not None:
            assert isinstance(value, TensorVariable)
            # The value variable must live in the same type class as the RV.
            assert rv.type.in_same_class(value.type)
            return Apply(self, [rv, value, *dims], [rv.type()])
        else:
            return Apply(self, [rv, *dims], [rv.type()])

    def infer_shape(self, fgraph, node, inputs_shape):
        # The single output mirrors the wrapped RV (input 0).
        # BUG FIX: `infer_shape` must return one shape per *output*; the
        # original returned `inputs_shape[0]` (the RV's shape tuple itself),
        # which PyTensor would misinterpret as a list of output shapes.
        return [inputs_shape[0]]

    def do_constant_folding(self, fgraph, node):
        # Never fold: these nodes only exist as labels in the graph.
        return False

    def perform(self, *args, **kwargs):
        raise RuntimeError("ValuedRVs should never be evaluated!")
39+
40+
41+
class FreeRV(ModelVar):
    """Marker Op labelling a free (unobserved) RV together with its value."""

    pass


class ObservedRV(ModelVar):
    """Marker Op labelling an observed RV together with its value."""

    pass


class Potential(ModelVar):
    """Marker Op labelling a model potential term."""

    pass


class Deterministic(ModelVar):
    """Marker Op labelling a model deterministic."""

    pass


# Singleton instances used as tagging callables when building/reading
# the FunctionGraph representation of a model.
free_rv = FreeRV()
observed_rv = ObservedRV()
potential = Potential()
deterministic = Deterministic()
61+
62+
63+
def toposort_replace(fgraph: FunctionGraph, replacements) -> None:
    """Apply ``(old, new)`` replacements to ``fgraph`` in topological order.

    Replacements whose ``old`` variable appears earlier in the graph's
    toposort are applied first.
    """
    order = fgraph.toposort()
    ordered_replacements = sorted(replacements, key=lambda repl: order.index(repl[0].owner))
    fgraph.replace_all(tuple(ordered_replacements), import_missing=True)
67+
68+
69+
def fgraph_from_model(model: Model) -> FunctionGraph:
    """Convert a PyMC Model to a FunctionGraph.

    RVs, potentials and deterministics become outputs of the graph, each
    wrapped in a dummy ``ModelVar`` Op (``FreeRV``/``ObservedRV``/
    ``Potential``/``Deterministic``) that records its role, value variable
    and dims.  Model coords are stashed on ``fgraph.coords``.

    Raises
    ------
    NotImplementedError
        If the model uses ``total_size`` or non-default initial values.
    """
    # Collect PyTensor variables
    rvs_to_values = model.rvs_to_values
    rvs = list(rvs_to_values.keys())
    values = list(rvs_to_values.values())
    free_rvs = model.free_RVs
    deterministics = model.deterministics
    potentials = model.potentials

    # Collect PyMC meta-info
    vars_to_dims = model.named_vars_to_dims
    coords = model.coords

    # TODO: Do something with these
    dim_lengths = model.dim_lengths
    rvs_to_transforms = model.rvs_to_transforms

    # Not supported yet
    if any(v is not None for v in model.rvs_to_total_sizes.values()):
        raise NotImplementedError("Cannot convert models with total_sizes")
    if any(v is not None for v in model.rvs_to_initial_values.values()):
        raise NotImplementedError("Cannot convert models with non-default initial_values")

    # We start the `dict` with mappings from the value variables to themselves,
    # to prevent them from being cloned.
    memo = {v: v for v in values}
    fgraph = FunctionGraph(
        outputs=rvs + potentials + deterministics,
        clone=True,
        memo=memo,
        copy_orphans=False,
        copy_inputs=False,
    )
    fgraph.coords = coords

    # Introduce dummy Ops to label different types of ModelVariables.
    # `memo` maps the original model variables to their clones in the fgraph.
    free_rvs_to_values = {memo[k]: v for k, v in rvs_to_values.items() if k in free_rvs}
    observed_rvs_to_values = {memo[k]: v for k, v in rvs_to_values.items() if k not in free_rvs}
    potentials = [memo[k] for k in potentials]
    deterministics = [memo[k] for k in deterministics]

    # Renamed from `vars` to avoid shadowing the builtin.
    output_vars = fgraph.outputs
    new_vars = []
    for var in output_vars:
        dims = vars_to_dims.get(var.name, None)
        if var in free_rvs_to_values:
            new_var = free_rv(var, free_rvs_to_values[var], dims)
        elif var in observed_rvs_to_values:
            new_var = observed_rv(var, observed_rvs_to_values[var], dims)
        elif var in potentials:
            new_var = potential(var, dims)
        elif var in deterministics:
            new_var = deterministic(var, dims)
        else:
            # BUG FIX: the original interpolated `new_var`, which is unbound on
            # the first iteration (NameError) or stale afterwards; report the
            # offending `var` instead.
            raise RuntimeError(f"Variable is not RV, Potential nor Deterministic: {var}")
        new_vars.append(new_var)

    toposort_replace(fgraph, tuple(zip(output_vars, new_vars)))
    return fgraph
130+
131+
132+
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
    """Build a new PyMC Model from a FunctionGraph produced by `fgraph_from_model`.

    The dummy ModelVar Ops are stripped from a clone of the graph and each
    wrapped variable is re-registered with the model according to its role
    (free RV, observed RV, potential or deterministic).
    """
    # Coords were stashed on the fgraph by `fgraph_from_model` (if present).
    model = Model(coords=getattr(fgraph, "coords", None))

    fgraph = fgraph.clone()
    # Map every ModelVar dummy output back to the raw variable it wraps (input 0).
    model_vars_to_vars = {
        model_node.outputs[0]: model_node.inputs[0]
        for model_node in fgraph.apply_nodes
        if isinstance(model_node.op, ModelVar)
    }
    # Strip the dummy Ops from the cloned graph.
    toposort_replace(fgraph, tuple(model_vars_to_vars.items()))

    # Re-register each variable with the model based on its ModelVar subclass.
    # NOTE(review): transforms are not restored (see FIXME in the tests) and
    # initvals are reset to defaults.
    for model_var in model_vars_to_vars.keys():
        if isinstance(model_var.owner.op, FreeRV):
            var, value, *dims = model_var.owner.inputs
            model.free_RVs.append(var)
            model.create_value_var(var, transform=None, value_var=value)
            model.set_initval(var, initval=None)
        elif isinstance(model_var.owner.op, ObservedRV):
            var, value, *dims = model_var.owner.inputs
            model.observed_RVs.append(var)
            model.create_value_var(var, transform=None, value_var=value)
        elif isinstance(model_var.owner.op, Potential):
            var, *dims = model_var.owner.inputs
            model.potentials.append(var)
        elif isinstance(model_var.owner.op, Deterministic):
            var, *dims = model_var.owner.inputs
            model.deterministics.append(var)
        else:
            continue  # Raise?

        # Convert symbolic string dims back into plain strings (None if absent).
        if not dims:
            dims = None
        else:
            dims = [dim.data for dim in dims]
        model.add_named_variable(var, dims=dims)

    return model

pymc_experimental/utils/pytensorf.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytensor
2+
from pytensor.graph import Constant, Type
3+
4+
5+
class StringType(Type[str]):
    """A PyTensor Type whose underlying values are plain Python strings."""

    def clone(self, **kwargs):
        # All StringType instances are interchangeable; a fresh one suffices.
        return type(self)()

    def filter(self, x, strict=False, allow_downcast=None):
        # Accept only genuine `str` values; reject everything else.
        if not isinstance(x, str):
            raise TypeError("Expected a string!")
        return x

    def __str__(self):
        return "string"

    @staticmethod
    def may_share_memory(a, b):
        # Two string values "share memory" only when they are the same object.
        return isinstance(a, str) and a is b
21+
22+
23+
# Singleton type instance used when building symbolic string constants.
stringtype = StringType()


class StringConstant(Constant):
    """A graph Constant holding a Python string (of type ``stringtype``)."""

    pass
28+
29+
30+
@pytensor._as_symbolic.register(str)
def as_symbolic_string(x, **kwargs):
    """Lift a plain Python string into the graph as a StringConstant."""
    # NOTE(review): this registers on pytensor's private `_as_symbolic`
    # dispatcher — confirm stability across pytensor versions.
    return StringConstant(stringtype, x)

0 commit comments

Comments
 (0)