diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index 0dc3bb6efb..91fe631798 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -13,12 +13,13 @@
 #   limitations under the License.
 import warnings
 
-from collections import defaultdict, deque
-from typing import Dict, Iterator, NewType, Optional, Set
+from collections import defaultdict
+from typing import Dict, Iterable, List, NewType, Optional, Set
 
 from aesara import function
 from aesara.compile.sharedvalue import SharedVariable
-from aesara.graph.basic import walk
+from aesara.graph import Apply
+from aesara.graph.basic import ancestors, walk
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.var import TensorConstant, TensorVariable
 
@@ -32,85 +33,64 @@
 class ModelGraph:
     def __init__(self, model):
         self.model = model
-        self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
+        self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
         self.var_list = self.model.named_vars.values()
-        self.transform_map = {
-            v.transformed: v.name for v in self.var_list if hasattr(v, "transformed")
-        }
-        self._deterministics = None
-
-    def get_deterministics(self, var):
-        """Compute the deterministic nodes of the graph, **not** including var itself."""
-        deterministics = []
-        attrs = ("transformed", "logpt")
-        for v in self.var_list:
-            if v != var and all(not hasattr(v, attr) for attr in attrs):
-                deterministics.append(v)
-        return deterministics
-
-    def _get_ancestors(self, var: TensorVariable, func) -> Set[TensorVariable]:
-        """Get all ancestors of a function, doing some accounting for deterministics."""
-
-        # this contains all of the variables in the model EXCEPT var...
-        vars = set(self.var_list)
-        vars.remove(var)
-
-        blockers = set()  # type: Set[TensorVariable]
-        retval = set()  # type: Set[TensorVariable]
-
-        def _expand(node) -> Optional[Iterator[TensorVariable]]:
-            if node in blockers:
-                return None
-            elif node in vars:
-                blockers.add(node)
-                retval.add(node)
-                return None
-            elif node.owner:
-                blockers.add(node)
-                return reversed(node.owner.inputs)
-            else:
-                return None
-
-        list(walk(deque([func]), _expand, bfs=True))
-        return retval
-
-    def _filter_parents(self, var, parents) -> Set[VarName]:
-        """Get direct parents of a var, as strings"""
-        keep = set()  # type: Set[VarName]
-        for p in parents:
-            if p == var:
-                continue
-            elif p.name in self.var_names:
-                keep.add(p.name)
-            elif p in self.transform_map:
-                if self.transform_map[p] != var.name:
-                    keep.add(self.transform_map[p])
-            else:
-                raise AssertionError(f"Do not know what to do with {get_var_name(p)}")
-        return keep
-
-    def get_parents(self, var: TensorVariable) -> Set[VarName]:
-        """Get the named nodes that are direct inputs to the var"""
-        # TODO: Update these lines, variables no longer have a `logpt` attribute
-        if hasattr(var, "transformed"):
-            func = var.transformed.logpt
-        elif hasattr(var, "logpt"):
-            func = var.logpt
-        else:
-            func = var
-
-        parents = self._get_ancestors(var, func)
-        return self._filter_parents(var, parents)
+
+    def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
+        if var.owner is None or var.owner.inputs is None:
+            return set()
+
+        def _expand(x):
+            if x.name:
+                return [x]
+            if isinstance(x.owner, Apply):
+                return reversed(x.owner.inputs)
+            return []
+
+        parents = {get_var_name(x) for x in walk(nodes=var.owner.inputs, expand=_expand) if x.name}
+
+        return parents
+
+    def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
+        if var_names is None:
+            return self._all_var_names
+
+        selected_names = set(var_names)
+
+        # .copy() because sets cannot change in size during iteration
+        for var_name in selected_names.copy():
+            if var_name not in self._all_var_names:
+                raise ValueError(f"{var_name} is not in this model.")
+
+            for model_var in self.var_list:
+                if hasattr(model_var.tag, "observations"):
+                    if model_var.tag.observations == self.model[var_name]:
+                        selected_names.add(model_var.name)
 
-    def make_compute_graph(self) -> Dict[str, Set[VarName]]:
+        selected_ancestors = set(
+            filter(
+                lambda rv: rv.name in self._all_var_names,
+                list(ancestors([self.model[var_name] for var_name in selected_names])),
+            )
+        )
+
+        for var in selected_ancestors.copy():
+            if hasattr(var.tag, "observations"):
+                selected_ancestors.add(var.tag.observations)
+
+        # ordering of self._all_var_names is important
+        return [var.name for var in selected_ancestors]
+
+    def make_compute_graph(
+        self, var_names: Optional[Iterable[VarName]] = None
+    ) -> Dict[VarName, Set[VarName]]:
         """Get map of var_name -> set(input var names) for the model"""
-        input_map = defaultdict(set)  # type: Dict[str, Set[VarName]]
+        input_map: Dict[VarName, Set[VarName]] = defaultdict(set)
 
-        for var_name in self.var_names:
+        for var_name in self.vars_to_plot(var_names):
             var = self.model[var_name]
-            key = var_name
-            val = self.get_parents(var)
-            input_map[key] = input_map[key].union(val)
+            parent_name = self.get_parent_names(var)
+            input_map[var_name] = input_map[var_name].union(parent_name)
 
             if hasattr(var.tag, "observations"):
                 try:
@@ -120,6 +100,7 @@ def make_compute_graph(self) -> Dict[str, Set[VarName]]:
                         input_map[obs_name] = input_map[obs_name].union({var_name})
                 except AttributeError:
                     pass
+
         return input_map
 
     def _make_node(self, var_name, graph, *, formatting: str = "plain"):
@@ -168,7 +149,7 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
     def _eval(self, var):
         return function([], var, mode="FAST_COMPILE")()
 
-    def get_plates(self):
+    def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
         """Rough but surprisingly accurate plate detection.
 
         Just groups by the shape of the underlying distribution. Will be wrong
@@ -176,10 +157,12 @@ def get_plates(self):
 
         Returns
         -------
-        dict: str -> set[str]
+        dict
+            Maps plate labels to the set of ``VarName``s inside the plate.
""" plates = defaultdict(set) - for var_name in self.var_names: + + for var_name in self.vars_to_plot(var_names): v = self.model[var_name] if var_name in self.model.RV_dims: plate_label = " x ".join( @@ -189,9 +172,10 @@ def get_plates(self): else: plate_label = " x ".join(map(str, self._eval(v.shape))) plates[plate_label].add(var_name) + return plates - def make_graph(self, formatting: str = "plain"): + def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"): """Make graphviz Digraph of PyMC model Returns @@ -207,25 +191,29 @@ def make_graph(self, formatting: str = "plain"): "\tconda install -c conda-forge python-graphviz" ) graph = graphviz.Digraph(self.model.name) - for plate_label, var_names in self.get_plates().items(): + for plate_label, all_var_names in self.get_plates(var_names).items(): if plate_label: # must be preceded by 'cluster' to get a box around it with graph.subgraph(name="cluster" + plate_label) as sub: - for var_name in var_names: + for var_name in all_var_names: self._make_node(var_name, sub, formatting=formatting) # plate label goes bottom right sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") else: - for var_name in var_names: + for var_name in all_var_names: self._make_node(var_name, graph, formatting=formatting) - for key, values in self.make_compute_graph().items(): - for value in values: - graph.edge(value.replace(":", "&"), key.replace(":", "&")) + for child, parents in self.make_compute_graph(var_names=var_names).items(): + # parents is a set of rv names that preceed child rv nodes + for parent in parents: + graph.edge(parent.replace(":", "&"), child.replace(":", "&")) + return graph -def model_to_graphviz(model=None, *, formatting: str = "plain"): +def model_to_graphviz( + model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain" +): """Produce a graphviz Digraph from a PyMC model. Requires graphviz, which may be installed most easily with @@ -240,7 +228,9 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"): ---------- model : pm.Model The model to plot. Not required when called from inside a modelcontext. - formatting : str + var_names : iterable of variable names, optional + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph + formatting : str, optional one of { "plain" } Examples @@ -275,4 +265,4 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"): "Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2 ) model = pm.modelcontext(model) - return ModelGraph(model).make_graph(formatting=formatting) + return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting) diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py index 14e3e0ebec..2b85d181a9 100644 --- a/pymc/tests/test_model_graph.py +++ b/pymc/tests/test_model_graph.py @@ -143,7 +143,7 @@ def setup_class(cls): def test_inputs(self): for child, parents_in_plot in self.compute_graph.items(): var = self.model[child] - parents_in_graph = self.model_graph.get_parents(var) + parents_in_graph = self.model_graph.get_parent_names(var) if isinstance(var, SharedVariable): # observed data also doesn't have parents in the compute graph! 
                 # But for the visualization we like them to become decendants of the
@@ -183,6 +183,85 @@ def test_checks_formatting(self):
         model_to_graphviz(self.model, formatting="plain_with_params")
 
 
+def model_with_different_descendants():
+    """
+    Model proposed by Michael to test variable selection functionality
+    From here: https://github.com/pymc-devs/pymc/pull/5634#pullrequestreview-916297509
+    """
+    with pm.Model() as pmodel2:
+        a = pm.Normal("a")
+        b = pm.Normal("b")
+        pm.Normal("c", a * b)
+        intermediate = pm.Deterministic("intermediate", a + b)
+        pred = pm.Deterministic("pred", intermediate * 3)
+
+        obs = pm.ConstantData("obs", 1.75)
+
+        L = pm.Normal("L", mu=1 + 0.5 * pred, observed=obs)
+
+    return pmodel2
+
+
+class TestParents:
+    @pytest.mark.parametrize(
+        "var_name, parent_names",
+        [
+            ("L", {"pred"}),
+            ("pred", {"intermediate"}),
+            ("intermediate", {"a", "b"}),
+            ("c", {"a", "b"}),
+            ("a", set()),
+            ("b", set()),
+        ],
+    )
+    def test_get_parent_names(self, var_name, parent_names):
+        mg = ModelGraph(model_with_different_descendants())
+        assert mg.get_parent_names(mg.model[var_name]) == parent_names
+
+
+class TestVariableSelection:
+    @pytest.mark.parametrize(
+        "var_names, vars_to_plot, compute_graph",
+        [
+            (["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
+            (
+                ["L"],
+                ["pred", "obs", "L", "intermediate", "a", "b"],
+                {
+                    "pred": {"intermediate"},
+                    "obs": {"L"},
+                    "L": {"pred"},
+                    "intermediate": {"a", "b"},
+                    "a": set(),
+                    "b": set(),
+                },
+            ),
+            (
+                ["obs"],
+                ["pred", "obs", "L", "intermediate", "a", "b"],
+                {
+                    "pred": {"intermediate"},
+                    "obs": {"L"},
+                    "L": {"pred"},
+                    "intermediate": {"a", "b"},
+                    "a": set(),
+                    "b": set(),
+                },
+            ),
+            # selecting ["c", "L"] is akin to selecting the entire graph
+            (
+                ["c", "L"],
+                ModelGraph(model_with_different_descendants()).vars_to_plot(),
+                ModelGraph(model_with_different_descendants()).make_compute_graph(),
+            ),
+        ],
+    )
+    def test_subgraph(self, var_names, vars_to_plot, compute_graph):
+        mg = ModelGraph(model_with_different_descendants())
+        assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
+        assert mg.make_compute_graph(var_names=var_names) == compute_graph
+
+
 class TestImputationModel(BaseModelGraphTest):
     model_func = model_with_imputations
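
For orientation, here is a minimal usage sketch of the new var_names argument introduced by this diff. It is not part of the patch itself: the model mirrors model_with_different_descendants() from the tests above, the expected subgraphs follow the TestVariableSelection cases, and it assumes a PyMC 4.x environment with this change applied and graphviz installed (pm.model_to_graphviz and pm.ConstantData are the regular PyMC APIs).

import pymc as pm

# Same toy model as model_with_different_descendants() in the tests above.
with pm.Model() as pmodel:
    a = pm.Normal("a")
    b = pm.Normal("b")
    pm.Normal("c", a * b)
    intermediate = pm.Deterministic("intermediate", a + b)
    pred = pm.Deterministic("pred", intermediate * 3)
    obs = pm.ConstantData("obs", 1.75)
    pm.Normal("L", mu=1 + 0.5 * pred, observed=obs)

# Entire model graph (previous behaviour; var_names defaults to None).
pm.model_to_graphviz(pmodel)

# Subgraph identified by "c": only a, b and c are plotted.
pm.model_to_graphviz(pmodel, var_names=["c"])

# Selecting "L" (or its observed data "obs") pulls in the whole lineage:
# pred, obs, L, intermediate, a, b.
pm.model_to_graphviz(pmodel, var_names=["L"])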