Skip to content

Commit 8084ef8

Browse files
committed
Do not infer graph_model node types based on variable Op class
1 parent b0d1066 commit 8084ef8

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

pymc/model_graph.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,11 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
154154
shape = "box"
155155
style = "rounded, filled"
156156
label = f"{var_name}\n~\nMutableData"
157-
elif v.owner and isinstance(v.owner.op, RandomVariable):
157+
elif v in self.model.basic_RVs:
158158
shape = "ellipse"
159-
if hasattr(v.tag, "observations"):
160-
# observed RV
159+
if v in self.model.observed_RVs:
161160
style = "filled"
162161
else:
163-
shape = "ellipse"
164162
style = None
165163
symbol = v.owner.op.__class__.__name__
166164
if symbol.endswith("RV"):

pymc/tests/test_model_graph.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,36 @@ def model_observation_dtype_casting():
232232
return model, compute_graph, plates
233233

234234

235+
def model_non_random_variable_rvs():
236+
"""Test that node types are not inferred based on the variable Op type, but
237+
model properties
238+
239+
See https://github.com/pymc-devs/pymc/issues/5766
240+
"""
241+
with pm.Model() as model:
242+
mu = pm.Normal(name="mu", mu=0.0, sigma=5.0)
243+
244+
y_raw = pm.Normal.dist(mu)
245+
y = pm.math.clip(y_raw, -3, 3)
246+
model.register_rv(y, name="y")
247+
248+
z_raw = pm.Normal.dist(y, shape=(5,))
249+
z = pm.math.clip(z_raw, -1, 1)
250+
model.register_rv(z, name="z", data=[0] * 5)
251+
252+
compute_graph = {
253+
"mu": set(),
254+
"y": {"mu"},
255+
"z": {"y"},
256+
}
257+
plates = {
258+
"": {"mu", "y"},
259+
"5": {"z"},
260+
}
261+
262+
return model, compute_graph, plates
263+
264+
235265
class BaseModelGraphTest(SeededTest):
236266
model_func = None
237267

@@ -360,3 +390,7 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
360390
mg = ModelGraph(model_with_different_descendants())
361391
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
362392
assert mg.make_compute_graph(var_names=var_names) == compute_graph
393+
394+
395+
class TestModelNonRandomVariableRVs(BaseModelGraphTest):
396+
model_func = model_non_random_variable_rvs

0 commit comments

Comments
 (0)