Skip to content

Commit edf40bf

Browse files
Added docstring and type hints
1 parent c9cb329 commit edf40bf

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

pymc/model_graph.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from aesara.graph import Apply
2222
from aesara.graph.basic import ancestors, walk
2323
from aesara.tensor.random.op import RandomVariable
24-
from aesara.tensor.var import TensorConstant
24+
from aesara.tensor.var import TensorConstant, TensorVariable
2525

2626
import pymc as pm
2727

@@ -36,7 +36,7 @@ def __init__(self, model):
3636
self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
3737
self.var_list = self.model.named_vars.values()
3838

39-
def get_parent_names(self, var):
39+
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
4040
if var.owner is None or var.owner.inputs is None:
4141
return set()
4242

@@ -51,7 +51,7 @@ def _expand(x):
5151

5252
return parents
5353

54-
def vars_to_plot(self, var_names: Optional[Iterable[str]] = None) -> List[str]:
54+
def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
5555
if var_names is None:
5656
return self._all_var_names
5757

@@ -82,10 +82,10 @@ def vars_to_plot(self, var_names: Optional[Iterable[str]] = None) -> List[str]:
8282
return [var.name for var in selected_ancestors]
8383

8484
def make_compute_graph(
85-
self, var_names: Optional[Iterable[str]] = None
86-
) -> Dict[str, Set[VarName]]:
85+
self, var_names: Optional[Iterable[VarName]] = None
86+
) -> Dict[VarName, Set[VarName]]:
8787
"""Get map of var_name -> set(input var names) for the model"""
88-
input_map = defaultdict(set) # type: Dict[str, Set[VarName]]
88+
input_map = defaultdict(set) # type: Dict[VarName, Set[VarName]]
8989

9090
for var_name in self.vars_to_plot(var_names):
9191
var = self.model[var_name]
@@ -149,15 +149,15 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
149149
def _eval(self, var):
150150
return function([], var, mode="FAST_COMPILE")()
151151

152-
def get_plates(self, var_names: Optional[Iterable[str]] = None):
152+
def get_plates(self, var_names: Optional[Iterable[VarName]] = None):
153153
"""Rough but surprisingly accurate plate detection.
154154
155155
Just groups by the shape of the underlying distribution. Will be wrong
156156
if there are two plates with the same shape.
157157
158158
Returns
159159
-------
160-
dict: str -> set[str]
160+
dict: VarName -> set(VarName)
161161
"""
162162
plates = defaultdict(set)
163163

@@ -174,7 +174,7 @@ def get_plates(self, var_names: Optional[Iterable[str]] = None):
174174

175175
return plates
176176

177-
def make_graph(self, var_names=None, formatting: str = "plain"):
177+
def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
178178
"""Make graphviz Digraph of PyMC model
179179
180180
Returns
@@ -211,7 +211,7 @@ def make_graph(self, var_names=None, formatting: str = "plain"):
211211

212212

213213
def model_to_graphviz(
214-
model=None, *, var_names: Optional[Iterable[str]] = None, formatting: str = "plain"
214+
model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
215215
):
216216
"""Produce a graphviz Digraph from a PyMC model.
217217
@@ -227,7 +227,9 @@ def model_to_graphviz(
227227
----------
228228
model : pm.Model
229229
The model to plot. Not required when called from inside a modelcontext.
230-
formatting : str
230+
var_names : iterable of variable names, optional
231+
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
232+
formatting : str, optional
231233
one of { "plain" }
232234
233235
Examples

0 commit comments

Comments
 (0)