From f48134722f01129f6fe3524ba5bca947ff13ecc4 Mon Sep 17 00:00:00 2001 From: Larry Dong Date: Mon, 26 Sep 2022 17:17:37 -0400 Subject: [PATCH 1/3] Add graphviz support for SymbolicRVs --- pymc/model_graph.py | 52 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 93f8c3a9af..6b245f4a92 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -21,17 +21,37 @@ from aesara.graph import Apply from aesara.graph.basic import ancestors, walk from aesara.scalar.basic import Cast +from aesara.tensor.basic import get_scalar_constant_value from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorConstant, TensorVariable import pymc as pm +from pymc.distributions import Discrete +from pymc.distributions.discrete import DiracDelta from pymc.util import get_default_varnames, get_var_name VarName = NewType("VarName", str) +def check_zip_graph_from_components(components): + """ + This helper function checks if a mixture sub-graph corresponds to a + zero-inflated distribution using its components, a list of length two. + """ + if not any(isinstance(var.owner.op, DiracDelta) for var in components): + return False + + dirac_delta_idx = 1 - int(isinstance(components[0].owner.op, DiracDelta)) + dirac_delta = components[dirac_delta_idx] + other_comp = components[1 - dirac_delta_idx] + + return (get_scalar_constant_value(dirac_delta.owner.inputs[3]) == 0) and isinstance( + other_comp.owner.op, Discrete + ) + + class ModelGraph: def __init__(self, model): self.model = model @@ -154,16 +174,40 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st shape = "box" style = "rounded, filled" label = f"{var_name}\n~\nMutableData" - elif v.owner and isinstance(v.owner.op, RandomVariable): + elif v.owner and (v in self.model.basic_RVs): shape = "ellipse" - if hasattr(v.tag, "observations"): + if v in self.model.observed_RVs: # observed RV style = "filled" else: - shape = "ellipse" style = None symbol = v.owner.op.__class__.__name__ - if symbol.endswith("RV"): + if symbol == "MarginalMixtureRV": + components = v.owner.inputs[2:] + if len(components) == 2: + component_names = [ + var.owner.op.__class__.__name__.replace("Unmeasurable", "")[:-2] + for var in components + ] + if check_zip_graph_from_components(components): + # ZeroInflated distribution + component_names.remove("DiracDelta") + symbol = f"ZeroInflated{component_names[0]}" + else: + # X-Y mixture + symbol = f"{'-'.join(component_names)}Mixture" + elif len(components) == 1: + # single component dispatch mixture + symbol = f"{components[0].owner.op.__class__.__name__.replace('Unmeasurable', '')[:-2]}Mixture" + else: + symbol = symbol[:-2] # just MarginalMixture + elif symbol == "CensoredRV": + censored_dist = v.owner.inputs[0] + symbol = symbol[:-2] + censored_dist.owner.op.__class__.__name__[:-2] + elif symbol == "TruncatedRV": + truncated_dist = v.owner.op.base_rv_op + symbol = symbol[:-2] + truncated_dist.__class__.__name__[:-2] + elif symbol.endswith("RV"): symbol = symbol[:-2] label = f"{var_name}\n~\n{symbol}" else: From c88664c0156df721efd12588adc708064750d65a Mon Sep 17 00:00:00 2001 From: Larry Dong Date: Mon, 3 Oct 2022 12:45:35 -0400 Subject: [PATCH 2/3] Add graphviz tests for the display of symbolic distributions --- pymc/tests/test_model_graph.py | 49 ++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py index 1df3588465..c1b483a1cd 100644 --- a/pymc/tests/test_model_graph.py +++ b/pymc/tests/test_model_graph.py @@ -22,6 +22,16 @@ import pymc as pm +from pymc.distributions import ( + Cauchy, + Censored, + GaussianRandomWalk, + Mixture, + Normal, + StudentT, + Truncated, + ZeroInflatedPoisson, +) from pymc.exceptions import ImputationWarning from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx from pymc.tests.helpers import SeededTest @@ -360,3 +370,42 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph): mg = ModelGraph(model_with_different_descendants()) assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot) assert mg.make_compute_graph(var_names=var_names) == compute_graph + + +@pytest.mark.parametrize( + "symbolic_dist, dist_kwargs, display_name", + [ + (ZeroInflatedPoisson, {"psi": 0.5, "mu": 5}, "ZeroInflatedPoisson"), + ( + Censored, + {"dist": Normal.dist(Normal.dist(0.0, 5.0), 2.0), "lower": -2, "upper": 2}, + "CensoredNormal", + ), + ( + Mixture, + {"w": [0.5, 0.5], "comp_dists": Normal.dist(0.0, 5.0, shape=(2,))}, + "NormalMixture", + ), + ( + Mixture, + {"w": [0.5, 0.5], "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0)]}, + "Normal-StudentTMixture", + ), + ( + Mixture, + { + "w": [0.3, 0.45, 0.25], + "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0), Cauchy.dist(1.0, 1.0)], + }, + "MarginalMixture", + ), + (GaussianRandomWalk, {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, "RandomWalk"), + (Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"), + ], +) +def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name): + with pm.Model() as model: + symbolic_dist("x", **dist_kwargs) + + graph = model_to_graphviz(model) + assert display_name in graph.source From 20cfc989d38edbbbb8fe765770a109f9e27f8541 Mon Sep 17 00:00:00 2001 From: Larry Dong Date: Mon, 24 Oct 2022 12:48:56 -0400 Subject: [PATCH 3/3] Use refactored _print_name to display SymbolicRVs --- pymc/model_graph.py | 27 ++++++++++++++------------- pymc/tests/test_model_graph.py | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 6b245f4a92..00ebd4aa6c 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -181,14 +181,11 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st style = "filled" else: style = None - symbol = v.owner.op.__class__.__name__ - if symbol == "MarginalMixtureRV": + symbol = v.owner.op._print_name[0] + if symbol == "MarginalMixture": components = v.owner.inputs[2:] if len(components) == 2: - component_names = [ - var.owner.op.__class__.__name__.replace("Unmeasurable", "")[:-2] - for var in components - ] + component_names = [var.owner.op._print_name[0] for var in components] if check_zip_graph_from_components(components): # ZeroInflated distribution component_names.remove("DiracDelta") @@ -198,17 +195,21 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st symbol = f"{'-'.join(component_names)}Mixture" elif len(components) == 1: # single component dispatch mixture - symbol = f"{components[0].owner.op.__class__.__name__.replace('Unmeasurable', '')[:-2]}Mixture" + symbol = f"{components[0].owner.op._print_name[0]}Mixture" else: symbol = symbol[:-2] # just MarginalMixture - elif symbol == "CensoredRV": + elif symbol == "Censored": censored_dist = v.owner.inputs[0] - symbol = symbol[:-2] + censored_dist.owner.op.__class__.__name__[:-2] - elif symbol == "TruncatedRV": + symbol = symbol + censored_dist.owner.op._print_name[0] + elif symbol == "Truncated": truncated_dist = v.owner.op.base_rv_op - symbol = symbol[:-2] + truncated_dist.__class__.__name__[:-2] - elif symbol.endswith("RV"): - symbol = symbol[:-2] + symbol = symbol + truncated_dist._print_name[0] + elif symbol == "RandomWalk": + innovation_dist = v.owner.inputs[1].owner.op._print_name[0] + if innovation_dist == "Normal": + symbol = "Gaussian" + symbol + else: + symbol = innovation_dist + symbol label = f"{var_name}\n~\n{symbol}" else: shape = "box" diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py index c1b483a1cd..91db13610f 100644 --- a/pymc/tests/test_model_graph.py +++ b/pymc/tests/test_model_graph.py @@ -28,6 +28,7 @@ GaussianRandomWalk, Mixture, Normal, + RandomWalk, StudentT, Truncated, ZeroInflatedPoisson, @@ -399,8 +400,21 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph): }, "MarginalMixture", ), - (GaussianRandomWalk, {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, "RandomWalk"), + ( + GaussianRandomWalk, + {"init_dist": Normal.dist(0.0, 5.0), "steps": 10}, + "GaussianRandomWalk", + ), (Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"), + ( + RandomWalk, + { + "innovation_dist": pm.StudentT.dist(7), + "init_dist": pm.Normal.dist(0, 1), + "steps": 10, + }, + "StudentTRandomWalk", + ), ], ) def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name):