Skip to content

Commit c88664c

Browse files
Add graphviz tests for the display of symbolic distributions
1 parent f481347 commit c88664c

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

pymc/tests/test_model_graph.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@
2222

2323
import pymc as pm
2424

25+
from pymc.distributions import (
26+
Cauchy,
27+
Censored,
28+
GaussianRandomWalk,
29+
Mixture,
30+
Normal,
31+
StudentT,
32+
Truncated,
33+
ZeroInflatedPoisson,
34+
)
2535
from pymc.exceptions import ImputationWarning
2636
from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
2737
from pymc.tests.helpers import SeededTest
@@ -360,3 +370,42 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
360370
mg = ModelGraph(model_with_different_descendants())
361371
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
362372
assert mg.make_compute_graph(var_names=var_names) == compute_graph
373+
374+
375+
@pytest.mark.parametrize(
376+
"symbolic_dist, dist_kwargs, display_name",
377+
[
378+
(ZeroInflatedPoisson, {"psi": 0.5, "mu": 5}, "ZeroInflatedPoisson"),
379+
(
380+
Censored,
381+
{"dist": Normal.dist(Normal.dist(0.0, 5.0), 2.0), "lower": -2, "upper": 2},
382+
"CensoredNormal",
383+
),
384+
(
385+
Mixture,
386+
{"w": [0.5, 0.5], "comp_dists": Normal.dist(0.0, 5.0, shape=(2,))},
387+
"NormalMixture",
388+
),
389+
(
390+
Mixture,
391+
{"w": [0.5, 0.5], "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0)]},
392+
"Normal-StudentTMixture",
393+
),
394+
(
395+
Mixture,
396+
{
397+
"w": [0.3, 0.45, 0.25],
398+
"comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0), Cauchy.dist(1.0, 1.0)],
399+
},
400+
"MarginalMixture",
401+
),
402+
(GaussianRandomWalk, {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, "RandomWalk"),
403+
(Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"),
404+
],
405+
)
406+
def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name):
407+
with pm.Model() as model:
408+
symbolic_dist("x", **dist_kwargs)
409+
410+
graph = model_to_graphviz(model)
411+
assert display_name in graph.source

0 commit comments

Comments
 (0)