diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index 93f8c3a9af..00ebd4aa6c 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -21,17 +21,37 @@
 from aesara.graph import Apply
 from aesara.graph.basic import ancestors, walk
 from aesara.scalar.basic import Cast
+from aesara.tensor.basic import get_scalar_constant_value
 from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorConstant, TensorVariable
 
 import pymc as pm
 
+from pymc.distributions import Discrete
+from pymc.distributions.discrete import DiracDelta
 from pymc.util import get_default_varnames, get_var_name
 
 VarName = NewType("VarName", str)
 
 
+def check_zip_graph_from_components(components):
+    """
+    This helper function checks if a mixture sub-graph corresponds to a
+    zero-inflated distribution using its components, a list of length two.
+    """
+    if not any(isinstance(var.owner.op, DiracDelta) for var in components):
+        return False
+
+    dirac_delta_idx = 1 - int(isinstance(components[0].owner.op, DiracDelta))
+    dirac_delta = components[dirac_delta_idx]
+    other_comp = components[1 - dirac_delta_idx]
+
+    return (get_scalar_constant_value(dirac_delta.owner.inputs[3]) == 0) and isinstance(
+        other_comp.owner.op, Discrete
+    )
+
+
 class ModelGraph:
     def __init__(self, model):
         self.model = model
@@ -154,17 +174,42 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
             shape = "box"
             style = "rounded, filled"
             label = f"{var_name}\n~\nMutableData"
-        elif v.owner and isinstance(v.owner.op, RandomVariable):
+        elif v.owner and (v in self.model.basic_RVs):
             shape = "ellipse"
-            if hasattr(v.tag, "observations"):
+            if v in self.model.observed_RVs:
                 # observed RV
                 style = "filled"
             else:
-                shape = "ellipse"
                 style = None
-            symbol = v.owner.op.__class__.__name__
-            if symbol.endswith("RV"):
-                symbol = symbol[:-2]
+            symbol = v.owner.op._print_name[0]
+            if symbol == "MarginalMixture":
+                components = v.owner.inputs[2:]
+                if len(components) == 2:
+                    component_names = [var.owner.op._print_name[0] for var in components]
+                    if check_zip_graph_from_components(components):
+                        # ZeroInflated distribution
+                        component_names.remove("DiracDelta")
+                        symbol = f"ZeroInflated{component_names[0]}"
+                    else:
+                        # X-Y mixture
+                        symbol = f"{'-'.join(component_names)}Mixture"
+                elif len(components) == 1:
+                    # single component dispatch mixture
+                    symbol = f"{components[0].owner.op._print_name[0]}Mixture"
+                else:
+                    symbol = symbol[:-2]  # just MarginalMixture
+            elif symbol == "Censored":
+                censored_dist = v.owner.inputs[0]
+                symbol = symbol + censored_dist.owner.op._print_name[0]
+            elif symbol == "Truncated":
+                truncated_dist = v.owner.op.base_rv_op
+                symbol = symbol + truncated_dist._print_name[0]
+            elif symbol == "RandomWalk":
+                innovation_dist = v.owner.inputs[1].owner.op._print_name[0]
+                if innovation_dist == "Normal":
+                    symbol = "Gaussian" + symbol
+                else:
+                    symbol = innovation_dist + symbol
             label = f"{var_name}\n~\n{symbol}"
         else:
             shape = "box"
diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py
index 1df3588465..91db13610f 100644
--- a/pymc/tests/test_model_graph.py
+++ b/pymc/tests/test_model_graph.py
@@ -22,6 +22,17 @@
 
 import pymc as pm
 
+from pymc.distributions import (
+    Cauchy,
+    Censored,
+    GaussianRandomWalk,
+    Mixture,
+    Normal,
+    RandomWalk,
+    StudentT,
+    Truncated,
+    ZeroInflatedPoisson,
+)
 from pymc.exceptions import ImputationWarning
 from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
 from pymc.tests.helpers import SeededTest
@@ -360,3 +371,55 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
         mg = ModelGraph(model_with_different_descendants())
         assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
         assert mg.make_compute_graph(var_names=var_names) == compute_graph
+
+
+@pytest.mark.parametrize(
+    "symbolic_dist, dist_kwargs, display_name",
+    [
+        (ZeroInflatedPoisson, {"psi": 0.5, "mu": 5}, "ZeroInflatedPoisson"),
+        (
+            Censored,
+            {"dist": Normal.dist(Normal.dist(0.0, 5.0), 2.0), "lower": -2, "upper": 2},
+            "CensoredNormal",
+        ),
+        (
+            Mixture,
+            {"w": [0.5, 0.5], "comp_dists": Normal.dist(0.0, 5.0, shape=(2,))},
+            "NormalMixture",
+        ),
+        (
+            Mixture,
+            {"w": [0.5, 0.5], "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0)]},
+            "Normal-StudentTMixture",
+        ),
+        (
+            Mixture,
+            {
+                "w": [0.3, 0.45, 0.25],
+                "comp_dists": [Normal.dist(0.0, 5.0), StudentT.dist(7.0), Cauchy.dist(1.0, 1.0)],
+            },
+            "MarginalMixture",
+        ),
+        (
+            GaussianRandomWalk,
+            {"init_dist": Normal.dist(0.0, 5.0), "steps": 10},
+            "GaussianRandomWalk",
+        ),
+        (Truncated, {"dist": StudentT.dist(7), "upper": 3.0}, "TruncatedStudentT"),
+        (
+            RandomWalk,
+            {
+                "innovation_dist": pm.StudentT.dist(7),
+                "init_dist": pm.Normal.dist(0, 1),
+                "steps": 10,
+            },
+            "StudentTRandomWalk",
+        ),
+    ],
+)
+def test_symbolic_distribution_display(symbolic_dist, dist_kwargs, display_name):
+    with pm.Model() as model:
+        symbolic_dist("x", **dist_kwargs)
+
+    graph = model_to_graphviz(model)
+    assert display_name in graph.source
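A minimal usage sketch of the relabelling this patch introduces, mirroring the added test (the model and variable names are illustrative only):

import pymc as pm
from pymc.model_graph import model_to_graphviz

# With the patch, the graph node is labelled "x ~ ZeroInflatedPoisson"
# instead of the generic "x ~ MarginalMixture" rewrite name.
with pm.Model() as model:
    pm.ZeroInflatedPoisson("x", psi=0.5, mu=5)

graph = model_to_graphviz(model)
assert "ZeroInflatedPoisson" in graph.source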