21
21
from aesara .graph import Apply
22
22
from aesara .graph .basic import ancestors , walk
23
23
from aesara .tensor .random .op import RandomVariable
24
- from aesara .tensor .var import TensorConstant
24
+ from aesara .tensor .var import TensorConstant , TensorVariable
25
25
26
26
import pymc as pm
27
27
@@ -36,7 +36,7 @@ def __init__(self, model):
36
36
self ._all_var_names = get_default_varnames (self .model .named_vars , include_transformed = False )
37
37
self .var_list = self .model .named_vars .values ()
38
38
39
- def get_parent_names (self , var ) :
39
+ def get_parent_names (self , var : TensorVariable ) -> Set [ VarName ] :
40
40
if var .owner is None or var .owner .inputs is None :
41
41
return set ()
42
42
@@ -51,7 +51,7 @@ def _expand(x):
51
51
52
52
return parents
53
53
54
- def vars_to_plot (self , var_names : Optional [Iterable [str ]] = None ) -> List [str ]:
54
+ def vars_to_plot (self , var_names : Optional [Iterable [VarName ]] = None ) -> List [VarName ]:
55
55
if var_names is None :
56
56
return self ._all_var_names
57
57
@@ -82,10 +82,10 @@ def vars_to_plot(self, var_names: Optional[Iterable[str]] = None) -> List[str]:
82
82
return [var .name for var in selected_ancestors ]
83
83
84
84
def make_compute_graph (
85
- self , var_names : Optional [Iterable [str ]] = None
86
- ) -> Dict [str , Set [VarName ]]:
85
+ self , var_names : Optional [Iterable [VarName ]] = None
86
+ ) -> Dict [VarName , Set [VarName ]]:
87
87
"""Get map of var_name -> set(input var names) for the model"""
88
- input_map = defaultdict (set ) # type: Dict[str , Set[VarName]]
88
+ input_map = defaultdict (set ) # type: Dict[VarName , Set[VarName]]
89
89
90
90
for var_name in self .vars_to_plot (var_names ):
91
91
var = self .model [var_name ]
@@ -149,15 +149,15 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
149
149
def _eval (self , var ):
150
150
return function ([], var , mode = "FAST_COMPILE" )()
151
151
152
- def get_plates (self , var_names : Optional [Iterable [str ]] = None ):
152
+ def get_plates (self , var_names : Optional [Iterable [VarName ]] = None ):
153
153
"""Rough but surprisingly accurate plate detection.
154
154
155
155
Just groups by the shape of the underlying distribution. Will be wrong
156
156
if there are two plates with the same shape.
157
157
158
158
Returns
159
159
-------
160
- dict: str -> set[str]
160
+ dict: VarName -> set(VarName)
161
161
"""
162
162
plates = defaultdict (set )
163
163
@@ -174,7 +174,7 @@ def get_plates(self, var_names: Optional[Iterable[str]] = None):
174
174
175
175
return plates
176
176
177
- def make_graph (self , var_names = None , formatting : str = "plain" ):
177
+ def make_graph (self , var_names : Optional [ Iterable [ VarName ]] = None , formatting : str = "plain" ):
178
178
"""Make graphviz Digraph of PyMC model
179
179
180
180
Returns
@@ -211,7 +211,7 @@ def make_graph(self, var_names=None, formatting: str = "plain"):
211
211
212
212
213
213
def model_to_graphviz (
214
- model = None , * , var_names : Optional [Iterable [str ]] = None , formatting : str = "plain"
214
+ model = None , * , var_names : Optional [Iterable [VarName ]] = None , formatting : str = "plain"
215
215
):
216
216
"""Produce a graphviz Digraph from a PyMC model.
217
217
@@ -227,7 +227,9 @@ def model_to_graphviz(
227
227
----------
228
228
model : pm.Model
229
229
The model to plot. Not required when called from inside a modelcontext.
230
- formatting : str
230
+ var_names : iterable of variable names, optional
231
+ Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
232
+ formatting : str, optional
231
233
one of { "plain" }
232
234
233
235
Examples
0 commit comments