|
| 1 | +import numpy as np |
| 2 | +import pymc as pm |
| 3 | +import pytensor.tensor as pt |
| 4 | +import pytest |
| 5 | +from pytensor.graph import Constant, FunctionGraph, node_rewriter |
| 6 | +from pytensor.graph.rewriting.basic import in2out |
| 7 | +from pytensor.tensor.exceptions import NotScalarConstantError |
| 8 | + |
| 9 | +from pymc_experimental.utils.model_fgraph import ( |
| 10 | + ModelFreeRV, |
| 11 | + ModelVar, |
| 12 | + fgraph_from_model, |
| 13 | + model_deterministic, |
| 14 | + model_free_rv, |
| 15 | + model_from_fgraph, |
| 16 | +) |
| 17 | + |
| 18 | + |
| 19 | +def test_basic(): |
| 20 | + """Test we can convert from a PyMC Model to a FunctionGraph and back""" |
| 21 | + with pm.Model(coords={"test_dim": range(3)}) as m_old: |
| 22 | + x = pm.Normal("x") |
| 23 | + y = pm.Deterministic("y", x + 1) |
| 24 | + w = pm.HalfNormal("w", pm.math.exp(y)) |
| 25 | + z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",)) |
| 26 | + pm.Potential("pot", x * 2) |
| 27 | + |
| 28 | + m_fgraph = fgraph_from_model(m_old) |
| 29 | + assert isinstance(m_fgraph, FunctionGraph) |
| 30 | + |
| 31 | + m_new = model_from_fgraph(m_fgraph) |
| 32 | + assert isinstance(m_new, pm.Model) |
| 33 | + |
| 34 | + assert m_new.coords == {"test_dim": tuple(range(3))} |
| 35 | + assert m_new._dim_lengths["test_dim"].eval() == 3 |
| 36 | + assert m_new.named_vars_to_dims == {"z": ["test_dim"]} |
| 37 | + |
| 38 | + named_vars = {"x", "y", "w", "z", "pot"} |
| 39 | + assert set(m_new.named_vars) == named_vars |
| 40 | + for named_var in named_vars: |
| 41 | + assert m_new[named_var] is not m_old[named_var] |
| 42 | + for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()): |
| 43 | + # Constants are not cloned |
| 44 | + if not isinstance(value_new, Constant): |
| 45 | + assert value_new is not value_old |
| 46 | + assert m_new["x"] in m_new.free_RVs |
| 47 | + assert m_new["w"] in m_new.free_RVs |
| 48 | + assert m_new["y"] in m_new.deterministics |
| 49 | + assert m_new["z"] in m_new.observed_RVs |
| 50 | + assert m_new["pot"] in m_new.potentials |
| 51 | + assert m_new.rvs_to_transforms[m_new["x"]] is None |
| 52 | + assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log |
| 53 | + assert m_new.rvs_to_transforms[m_new["z"]] is None |
| 54 | + |
| 55 | + # Test random |
| 56 | + new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1) |
| 57 | + old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1) |
| 58 | + np.testing.assert_array_equal(new_y_draw, old_y_draw) |
| 59 | + np.testing.assert_array_equal(new_z_draw, old_z_draw) |
| 60 | + |
| 61 | + # Test logp |
| 62 | + ip = m_new.initial_point() |
| 63 | + np.testing.assert_equal( |
| 64 | + m_new.compile_logp()(ip), |
| 65 | + m_old.compile_logp()(ip), |
| 66 | + ) |
| 67 | + |
| 68 | + |
| 69 | +def test_data(): |
| 70 | + """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly. |
| 71 | +
|
| 72 | + Everything should be preserved across new and old models, except for shared RNGs |
| 73 | + """ |
| 74 | + with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old: |
| 75 | + x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",)) |
| 76 | + y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",)) |
| 77 | + b0 = pm.ConstantData("b0", 0.0) |
| 78 | + b1 = pm.Normal("b1") |
| 79 | + mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",)) |
| 80 | + obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",)) |
| 81 | + |
| 82 | + m_new = model_from_fgraph(fgraph_from_model(m_old)) |
| 83 | + |
| 84 | + # ConstantData is preserved |
| 85 | + assert m_new["b0"].data == m_old["b0"].data |
| 86 | + |
| 87 | + # Shared non-rng shared variables are preserved |
| 88 | + assert m_new["x"].container is x.container |
| 89 | + assert m_new["y"].container is y.container |
| 90 | + assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"] |
| 91 | + |
| 92 | + # Shared rng shared variables are not preserved |
| 93 | + m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container |
| 94 | + |
| 95 | + with m_old: |
| 96 | + pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)}) |
| 97 | + |
| 98 | + assert m_new.dim_lengths["test_dim"].eval() == 2 |
| 99 | + np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0]) |
| 100 | + |
| 101 | + |
| 102 | +def test_deterministics(): |
| 103 | + """Test handling of deterministics. |
| 104 | +
|
| 105 | + We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome |
| 106 | + However we want them in the middle of Model.basic_RVs, so they display nicely in graphviz |
| 107 | +
|
| 108 | + There is one edge case that has to be considered, when a Deterministic is just a copy of a RV. |
| 109 | + In that case we don't bother to reintroduce it in between other Model.basic_RVs |
| 110 | + """ |
| 111 | + with pm.Model() as m: |
| 112 | + x = pm.Normal("x") |
| 113 | + mu = pm.Deterministic("mu", pm.math.abs(x)) |
| 114 | + sigma = pm.math.exp(x) |
| 115 | + pm.Deterministic("sigma", sigma) |
| 116 | + y = pm.Normal("y", mu, sigma) |
| 117 | + # Special case where the Deterministic |
| 118 | + # is a direct view on another model variable |
| 119 | + y_ = pm.Deterministic("y_", y) |
| 120 | + # Just for kicks, make it a double one! |
| 121 | + y__ = pm.Deterministic("y__", y_) |
| 122 | + z = pm.Normal("z", y__) |
| 123 | + |
| 124 | + # Deterministic mu is in the graph of x to y but not sigma |
| 125 | + assert m["y"].owner.inputs[3] is m["mu"] |
| 126 | + assert m["y"].owner.inputs[4] is not m["sigma"] |
| 127 | + |
| 128 | + fg = fgraph_from_model(m) |
| 129 | + |
| 130 | + # Check that no Deterministics are in graph of x to y and y to z |
| 131 | + x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs |
| 132 | + # [Det(mu), Det(sigma)] |
| 133 | + mu = det_mu.owner.inputs[0] |
| 134 | + sigma = det_sigma.owner.inputs[0] |
| 135 | + # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))] |
| 136 | + assert y.owner.inputs[0].owner.inputs[3] is mu |
| 137 | + assert y.owner.inputs[0].owner.inputs[4] is sigma |
| 138 | + # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))] |
| 139 | + assert z.owner.inputs[0].owner.inputs[3] is y |
| 140 | + # [Det(y), Det(y)], not [Det(y), Det(Det(y))] |
| 141 | + assert det_y_.owner.inputs[0] is y |
| 142 | + assert det_y__.owner.inputs[0] is y |
| 143 | + assert det_y_ is not det_y__ |
| 144 | + |
| 145 | + # Both mu and sigma deterministics are now in the graph of x to y |
| 146 | + m = model_from_fgraph(fg) |
| 147 | + assert m["y"].owner.inputs[3] is m["mu"] |
| 148 | + assert m["y"].owner.inputs[4] is m["sigma"] |
| 149 | + # But not y_* in y to z, since there was no real Op in between |
| 150 | + assert m["z"].owner.inputs[3] is m["y"] |
| 151 | + assert m["y_"].owner.inputs[0] is m["y"] |
| 152 | + assert m["y__"].owner.inputs[0] is m["y"] |
| 153 | + |
| 154 | + |
| 155 | +def test_context_error(): |
| 156 | + """Test that model_from_fgraph fails when called inside a Model context. |
| 157 | +
|
| 158 | + We can't allow it, because the new Model that's returned would be a child of whatever Model context is active. |
| 159 | + """ |
| 160 | + with pm.Model() as m: |
| 161 | + x = pm.Normal("x") |
| 162 | + |
| 163 | + fg = fgraph_from_model(m) |
| 164 | + |
| 165 | + with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"): |
| 166 | + model_from_fgraph(fg) |
| 167 | + |
| 168 | + |
| 169 | +def test_sub_model_error(): |
| 170 | + """Test Error is raised when trying to convert a sub-model to fgraph.""" |
| 171 | + with pm.Model() as m: |
| 172 | + x = pm.Beta("x", 1, 1) |
| 173 | + with pm.Model() as sub_m: |
| 174 | + y = pm.Normal("y", x) |
| 175 | + |
| 176 | + nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)] |
| 177 | + assert len(nodes) == 2 |
| 178 | + assert isinstance(nodes[0].op, pm.Beta) |
| 179 | + assert isinstance(nodes[1].op, pm.Normal) |
| 180 | + |
| 181 | + with pytest.raises(ValueError, match="Nested sub-models cannot be converted"): |
| 182 | + fgraph_from_model(sub_m) |
| 183 | + |
| 184 | + |
| 185 | +@pytest.fixture() |
| 186 | +def non_centered_rewrite(): |
| 187 | + @node_rewriter(tracks=[ModelFreeRV]) |
| 188 | + def non_centered_param(fgraph: FunctionGraph, node): |
| 189 | + """Rewrite that replaces centered normal by non-centered parametrization.""" |
| 190 | + |
| 191 | + rv, value, *dims = node.inputs |
| 192 | + if not isinstance(rv.owner.op, pm.Normal): |
| 193 | + return |
| 194 | + rng, size, dtype, loc, scale = rv.owner.inputs |
| 195 | + |
| 196 | + # Only apply rewrite if size information is explicit |
| 197 | + if size.ndim == 0: |
| 198 | + return None |
| 199 | + |
| 200 | + try: |
| 201 | + is_unit = ( |
| 202 | + pt.get_underlying_scalar_constant_value(loc) == 0 |
| 203 | + and pt.get_underlying_scalar_constant_value(scale) == 1 |
| 204 | + ) |
| 205 | + except NotScalarConstantError: |
| 206 | + is_unit = False |
| 207 | + |
| 208 | + # Nothing to do here |
| 209 | + if is_unit: |
| 210 | + return |
| 211 | + |
| 212 | + raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng) |
| 213 | + raw_norm.name = f"{rv.name}_raw_" |
| 214 | + raw_norm_value = raw_norm.clone() |
| 215 | + fgraph.add_input(raw_norm_value) |
| 216 | + raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims) |
| 217 | + |
| 218 | + new_norm = loc + raw_norm * scale |
| 219 | + new_norm.name = rv.name |
| 220 | + new_norm_det = model_deterministic(new_norm, *dims) |
| 221 | + fgraph.add_output(new_norm_det) |
| 222 | + |
| 223 | + return [new_norm] |
| 224 | + |
| 225 | + return in2out(non_centered_param) |
| 226 | + |
| 227 | + |
| 228 | +def test_fgraph_rewrite(non_centered_rewrite): |
| 229 | + """Test we can apply a simple rewrite to a PyMC Model.""" |
| 230 | + |
| 231 | + with pm.Model(coords={"subject": range(10)}) as m_old: |
| 232 | + group_mean = pm.Normal("group_mean") |
| 233 | + group_std = pm.HalfNormal("group_std") |
| 234 | + subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",)) |
| 235 | + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",)) |
| 236 | + |
| 237 | + fg = fgraph_from_model(m_old) |
| 238 | + non_centered_rewrite.apply(fg) |
| 239 | + |
| 240 | + m_new = model_from_fgraph(fg) |
| 241 | + assert m_new.named_vars_to_dims == { |
| 242 | + "subject_mean": ["subject"], |
| 243 | + "subject_mean_raw_": ["subject"], |
| 244 | + "obs": ["subject"], |
| 245 | + } |
| 246 | + assert set(m_new.named_vars) == { |
| 247 | + "group_mean", |
| 248 | + "group_std", |
| 249 | + "subject_mean_raw_", |
| 250 | + "subject_mean", |
| 251 | + "obs", |
| 252 | + } |
| 253 | + assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"} |
| 254 | + assert {rv.name for rv in m_new.observed_RVs} == {"obs"} |
| 255 | + assert {rv.name for rv in m_new.deterministics} == {"subject_mean"} |
| 256 | + |
| 257 | + with pm.Model() as m_ref: |
| 258 | + group_mean = pm.Normal("group_mean") |
| 259 | + group_std = pm.HalfNormal("group_std") |
| 260 | + subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,)) |
| 261 | + subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std) |
| 262 | + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10)) |
| 263 | + |
| 264 | + np.testing.assert_array_equal( |
| 265 | + pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1), |
| 266 | + pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1), |
| 267 | + ) |
| 268 | + |
| 269 | + ip = m_new.initial_point() |
| 270 | + np.testing.assert_equal( |
| 271 | + m_new.compile_logp()(ip), |
| 272 | + m_ref.compile_logp()(ip), |
| 273 | + ) |
0 commit comments