|
22 | 22 |
|
23 | 23 | import pymc as pm
|
24 | 24 |
|
| 25 | +from pymc.distributions import ( |
| 26 | + Cauchy, |
| 27 | + Censored, |
| 28 | + GaussianRandomWalk, |
| 29 | + Mixture, |
| 30 | + Normal, |
| 31 | + StudentT, |
| 32 | + Truncated, |
| 33 | + ZeroInflatedPoisson, |
| 34 | +) |
25 | 35 | from pymc.exceptions import ImputationWarning
|
26 | 36 | from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
|
27 | 37 | from pymc.tests.helpers import SeededTest
|
@@ -360,3 +370,42 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
|
360 | 370 | mg = ModelGraph(model_with_different_descendants())
|
361 | 371 | assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
|
362 | 372 | assert mg.make_compute_graph(var_names=var_names) == compute_graph
|
| 373 | + |
| 374 | + |
| 375 | +@pytest.mark.parametrize( |
| 376 | + "symbolic_dist, dist_kwargs, display_name", |
| 377 | + [ |
| 378 | + (ZeroInflatedPoisson, {"psi": 0.5, "mu": 5}, "ZeroInflatedPoisson"), |
| 379 | + ( |
| 380 | + Censored, |
| 381 | + {"dist": Normal.dist(Normal.dist(0.0, 5.0), 2.0), "lower": -2, "upper": 2}, |
| 382 | + "CensoredNormal", |
| 383 | + ), |
| 384 | + ( |
| 385 | + Mixture, |
| 386 | + {"w": [0.5, 0.5], "comp_dists": Normal.dist(0.0, 5.0, shape=(2,))}, |
| 387 | + "NormalMixture", |
| 388 | + ), |
| 389 | + ( |
| 390 | + Mixture, |
| 391 | + {"w": [0.5, 0.5], "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0)]}, |
| 392 | + "Normal-StudentTMixture", |
| 393 | + ), |
| 394 | + ( |
| 395 | + Mixture, |
| 396 | + { |
| 397 | + "w": [0.3, 0.45, 0.25], |
| 398 | + "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0), Cauchy.dist(1.0, 1.0)], |
| 399 | + }, |
| 400 | + "MarginalMixture", |
| 401 | + ), |
| 402 | + (GaussianRandomWalk, {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, "RandomWalk"), |
| 403 | + (Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"), |
| 404 | + ], |
| 405 | +) |
| 406 | +def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name): |
| 407 | + with pm.Model() as model: |
| 408 | + symbolic_dist("x", **dist_kwargs) |
| 409 | + |
| 410 | + graph = model_to_graphviz(model) |
| 411 | + assert display_name in graph.source |
0 commit comments