@@ -181,14 +181,11 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
181
181
style = "filled"
182
182
else :
183
183
style = None
184
- symbol = v .owner .op .__class__ . __name__
185
- if symbol == "MarginalMixtureRV " :
184
+ symbol = v .owner .op ._print_name [ 0 ]
185
+ if symbol == "MarginalMixture " :
186
186
components = v .owner .inputs [2 :]
187
187
if len (components ) == 2 :
188
- component_names = [
189
- var .owner .op .__class__ .__name__ .replace ("Unmeasurable" , "" )[:- 2 ]
190
- for var in components
191
- ]
188
+ component_names = [var .owner .op ._print_name [0 ] for var in components ]
192
189
if check_zip_graph_from_components (components ):
193
190
# ZeroInflated distribution
194
191
component_names .remove ("DiracDelta" )
@@ -198,17 +195,21 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
198
195
symbol = f"{ '-' .join (component_names )} Mixture"
199
196
elif len (components ) == 1 :
200
197
# single component dispatch mixture
201
- symbol = f"{ components [0 ].owner .op .__class__ . __name__ . replace ( 'Unmeasurable' , '' )[: - 2 ]} Mixture"
198
+ symbol = f"{ components [0 ].owner .op ._print_name [ 0 ]} Mixture"
202
199
else :
203
200
symbol = symbol [:- 2 ] # just MarginalMixture
204
- elif symbol == "CensoredRV " :
201
+ elif symbol == "Censored " :
205
202
censored_dist = v .owner .inputs [0 ]
206
- symbol = symbol [: - 2 ] + censored_dist .owner .op .__class__ . __name__ [: - 2 ]
207
- elif symbol == "TruncatedRV " :
203
+ symbol = symbol + censored_dist .owner .op ._print_name [ 0 ]
204
+ elif symbol == "Truncated " :
208
205
truncated_dist = v .owner .op .base_rv_op
209
- symbol = symbol [:- 2 ] + truncated_dist .__class__ .__name__ [:- 2 ]
210
- elif symbol .endswith ("RV" ):
211
- symbol = symbol [:- 2 ]
206
+ symbol = symbol + truncated_dist ._print_name [0 ]
207
+ elif symbol == "RandomWalk" :
208
+ innovation_dist = v .owner .inputs [1 ].owner .op ._print_name [0 ]
209
+ if innovation_dist == "Normal" :
210
+ symbol = "Gaussian" + symbol
211
+ else :
212
+ symbol = innovation_dist + symbol
212
213
label = f"{ var_name } \n ~\n { symbol } "
213
214
else :
214
215
shape = "box"
0 commit comments