Skip to content

Commit 20cfc98

Browse files
Use refactored _print_name to display SymbolicRVs
1 parent c88664c commit 20cfc98

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

pymc/model_graph.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,11 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
181181
style = "filled"
182182
else:
183183
style = None
184-
symbol = v.owner.op.__class__.__name__
185-
if symbol == "MarginalMixtureRV":
184+
symbol = v.owner.op._print_name[0]
185+
if symbol == "MarginalMixture":
186186
components = v.owner.inputs[2:]
187187
if len(components) == 2:
188-
component_names = [
189-
var.owner.op.__class__.__name__.replace("Unmeasurable", "")[:-2]
190-
for var in components
191-
]
188+
component_names = [var.owner.op._print_name[0] for var in components]
192189
if check_zip_graph_from_components(components):
193190
# ZeroInflated distribution
194191
component_names.remove("DiracDelta")
@@ -198,17 +195,21 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
198195
symbol = f"{'-'.join(component_names)}Mixture"
199196
elif len(components) == 1:
200197
# single component dispatch mixture
201-
symbol = f"{components[0].owner.op.__class__.__name__.replace('Unmeasurable', '')[:-2]}Mixture"
198+
symbol = f"{components[0].owner.op._print_name[0]}Mixture"
202199
else:
203200
symbol = symbol[:-2] # just MarginalMixture
204-
elif symbol == "CensoredRV":
201+
elif symbol == "Censored":
205202
censored_dist = v.owner.inputs[0]
206-
symbol = symbol[:-2] + censored_dist.owner.op.__class__.__name__[:-2]
207-
elif symbol == "TruncatedRV":
203+
symbol = symbol + censored_dist.owner.op._print_name[0]
204+
elif symbol == "Truncated":
208205
truncated_dist = v.owner.op.base_rv_op
209-
symbol = symbol[:-2] + truncated_dist.__class__.__name__[:-2]
210-
elif symbol.endswith("RV"):
211-
symbol = symbol[:-2]
206+
symbol = symbol + truncated_dist._print_name[0]
207+
elif symbol == "RandomWalk":
208+
innovation_dist = v.owner.inputs[1].owner.op._print_name[0]
209+
if innovation_dist == "Normal":
210+
symbol = "Gaussian" + symbol
211+
else:
212+
symbol = innovation_dist + symbol
212213
label = f"{var_name}\n~\n{symbol}"
213214
else:
214215
shape = "box"

pymc/tests/test_model_graph.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
GaussianRandomWalk,
2929
Mixture,
3030
Normal,
31+
RandomWalk,
3132
StudentT,
3233
Truncated,
3334
ZeroInflatedPoisson,
@@ -399,8 +400,21 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
399400
},
400401
"MarginalMixture",
401402
),
402-
(GaussianRandomWalk, {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, "RandomWalk"),
403+
(
404+
GaussianRandomWalk,
405+
{"init_dist": Normal.dist(0.0, 5.0), "steps": 10},
406+
"GaussianRandomWalk",
407+
),
403408
(Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"),
409+
(
410+
RandomWalk,
411+
{
412+
"innovation_dist": pm.StudentT.dist(7),
413+
"init_dist": pm.Normal.dist(0, 1),
414+
"steps": 10,
415+
},
416+
"StudentTRandomWalk",
417+
),
404418
],
405419
)
406420
def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name):

0 commit comments

Comments
 (0)