@@ -41,6 +41,12 @@ def str_for_model_var(
41
41
42
42
Intended for Distribution, Deterministic, and Potential.
43
43
"""
44
+ if not (
45
+ _has_owner (var ) and isinstance (var .owner .op , (RandomVariable , SymbolicRandomVariable ))
46
+ ) and not _is_potential_or_deterministic (var ):
47
+ raise ValueError (
48
+ f"Variable for pretty-printing must be a model variable or the output of .dist(). Received unsupported variable { var } "
49
+ )
44
50
var_name , dist_name , args_str = _get_varname_distname_args (
45
51
var , formatting = formatting , dist_name = dist_name
46
52
)
@@ -56,6 +62,9 @@ def str_for_model_var(
56
62
if formatting == "latex" :
57
63
out = rf"${ var_name } \sim { dist_name } ({ args_str } )$"
58
64
elif formatting == "plain" :
65
+ var_name = var_name .replace ("~" , "-" )
66
+ dist_name = dist_name .replace ("~" , "-" )
67
+ args_str = args_str .replace ("~" , "-" )
59
68
out = f"{ var_name } ~ { dist_name } ({ args_str } )"
60
69
else :
61
70
raise ValueError (
@@ -72,7 +81,7 @@ def str_for_model(model: Model, formatting: str = "plain", **kwargs) -> str:
72
81
rv_reprs = [rv .str_repr (formatting = formatting , ** kwargs ) for rv in all_rv ]
73
82
if not rv_reprs :
74
83
return ""
75
- if "latex" in formatting :
84
+ if formatting == "latex" :
76
85
rv_reprs = [rv_repr .replace (r"\sim" , r"&\sim &" ).strip ("$" ) for rv_repr in rv_reprs ]
77
86
return r"""$$
78
87
\begin{{array}}{{rcl}}
@@ -98,36 +107,45 @@ def _get_varname_distname_args(
98
107
) -> Tuple [str , str , str ]:
99
108
"""Generate formatted strings for the name, distribution name, and
100
109
arguments list of a Model variable.
110
+
111
+ For Distribution, Potential, Deterministic, or .dist().
101
112
"""
102
113
# Name and distribution name
103
- name = var .name if var .name is not None else "<unnamed>"
104
- if not dist_name and hasattr (var .owner .op , "_print_name" ):
114
+ name = var .name if var .name is not None else "<unnamed>" # May be missing if from a dist()
115
+ if (
116
+ not dist_name
117
+ and _has_owner (var )
118
+ and hasattr (var .owner .op , "_print_name" )
119
+ and var .owner .op ._print_name
120
+ ):
105
121
# The _print_name tuple is necessary for maximum prettiness because a few RVs
106
122
# use special formatting (e.g. superscripts) for their latex print name
107
123
dist_name = (
108
124
var .owner .op ._print_name [1 ] if formatting == "latex" else var .owner .op ._print_name [0 ]
109
125
)
110
126
elif not dist_name :
111
- dist_name = "Unknown"
127
+ raise ValueError (
128
+ f"Missing distribution name for model variable: { var } . Provide one via the"
129
+ " _print_name attribute of your RandomVariable."
130
+ )
112
131
if formatting == "latex" :
113
132
name = _latex_clean_command (name , command = "text" )
114
133
dist_name = _latex_clean_command (dist_name , command = "operatorname" )
134
+
115
135
# Arguments passed to the distribution or expression
116
- if isinstance (var .owner .op , RandomVariable ):
117
- # var is the RV from a Distribution.
136
+ if _has_owner ( var ) and isinstance (var .owner .op , RandomVariable ):
137
+ # var is the RV or dist() from a Distribution.
118
138
dist_args = var .owner .inputs [3 :] # First 3 inputs are always rng, size, dtype
119
- elif isinstance (var .owner .op , SymbolicRandomVariable ):
139
+ elif _has_owner ( var ) and isinstance (var .owner .op , SymbolicRandomVariable ):
120
140
# var is a symbolic RV from a Distribution.
121
141
dist_args = [
122
142
x
123
143
for x in var .owner .inputs
124
144
if not isinstance (x , (RandomStateSharedVariable , RandomGeneratorSharedVariable ))
125
145
]
126
- elif _is_potential_or_deterministic (var ):
127
- # var is a Deterministic or a Potential.
128
- dist_args = _walk_expression_args (var )
129
146
else :
130
- raise ValueError (f"Unable to parse arguments for variable" )
147
+ # Assume that var is a Deterministic or a Potential.
148
+ dist_args = _walk_expression_args (var )
131
149
args_str = _str_for_args_list (dist_args , formatting = formatting )
132
150
if _is_potential_or_deterministic (var ):
133
151
args_str = f"f({ args_str } )" # TODO do we still want to do this?
@@ -153,32 +171,37 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
153
171
if var_data .size == 1 :
154
172
return f"{ var_data .flatten ()[0 ]:.3g} "
155
173
else :
156
- return f"<{ var_type } { var_data .shape } >" # TODO shape info or nah?
157
- elif isinstance (var . owner . op , DimShuffle ):
158
- # Recurse
159
- return _str_for_input_var ( var . owner . inputs [ 0 ], formatting = formatting )
160
- elif _is_potential_or_deterministic ( var ) or isinstance (
161
- var . owner . op , ( RandomVariable , SymbolicRandomVariable )
162
- ):
163
- if var . name :
174
+ return f"<{ var_type } { var_data .shape } >"
175
+ elif _has_owner (var ):
176
+ if isinstance ( var . owner . op , DimShuffle ):
177
+ # Recurse
178
+ return _str_for_input_var ( var . owner . inputs [ 0 ], formatting = formatting )
179
+ elif _is_potential_or_deterministic ( var ) or isinstance (
180
+ var . owner . op , ( RandomVariable , SymbolicRandomVariable )
181
+ ) :
164
182
# Give the name of the RV/Potential/Deterministic if available
165
- return var .name
166
- else :
183
+ if var .name :
184
+ return var . name
167
185
# But if rv comes from .dist() we print the distribution with its args
168
- _ , dist_name , args_str = _get_varname_distname_args (var , formatting = formatting )
169
- return f"{ dist_name } ({ args_str } )"
170
- elif hasattr (var , "owner" ) and var .owner :
171
- # Return an "expression" i.e. indicate that this variable is a function of other
172
- # variables. Looks like f(arg1, ..., argN). Previously _str_for_expression()
173
- args = _walk_expression_args (var )
174
- args_str = _str_for_args_list (args , formatting = formatting )
175
- return f"f({ args_str } )"
186
+ else :
187
+ _ , dist_name , args_str = _get_varname_distname_args (var , formatting = formatting )
188
+ return f"{ dist_name } ({ args_str } )"
189
+ else :
190
+ # Return an "expression" i.e. indicate that this variable is a function of other
191
+ # variables. Looks like f(arg1, ..., argN). Previously _str_for_expression()
192
+ args = _walk_expression_args (var )
193
+ args_str = _str_for_args_list (args , formatting = formatting )
194
+ return f"f({ args_str } )"
176
195
else :
177
- raise ValueError ("Unidentified variable in dist or expression args" )
196
+ raise ValueError (
197
+ f"Unidentified variable in dist or expression args: { var } . If you think this is a bug, please create an issue in the project Github."
198
+ )
178
199
179
200
180
201
def _walk_expression_args (var : Variable ) -> List [Variable ]:
181
202
"""Find all arguments of an expression"""
203
+ if not var .owner :
204
+ return []
182
205
183
206
def _expand (x ):
184
207
if x .owner and (not isinstance (x .owner .op , (RandomVariable , SymbolicRandomVariable ))):
@@ -210,7 +233,10 @@ def _str_for_args_list(args: List[Variable], formatting: str) -> str:
210
233
211
234
def _latex_clean_command (text : str , command : str ) -> str :
212
235
r"""Prepare text for LaTeX and maybe wrap it in a \command{}."""
213
- text = text .replace ("$" , r"\$" ) # TODO do we want to keep dollar signs or strip them?
236
+ text = text .replace ("$" , r"\$" )
237
+ # str_for_model() uses \sim to format the array, and properly
238
+ # tilde in latex is hard. So we replace for simplicity
239
+ text = text .replace ("~" , "-" )
214
240
if not text .startswith (rf"\{ command } " ):
215
241
# The printing module is designed such that text never passes through this
216
242
# function more than once. However, in some cases the text may have already
@@ -225,14 +251,11 @@ def _latex_clean_command(text: str, command: str) -> str:
225
251
# command itself, writing the character, then continuing on with the same command.
226
252
if command == "text" :
227
253
text = text .replace ("_" , rf"}}\_\{ command } {{" )
228
- text = text .replace ("~" , rf"}}~\{ command } {{" )
229
254
return text
230
255
231
256
232
257
def _is_potential_or_deterministic (var : Variable ) -> bool :
233
- # This is a bit hacky but seems like the best we got. We should write
234
- # a test to make sure that Deterministic and Potential don't get updated
235
- # without also modifying this function.
258
+ # This is a bit hacky but seems like the best we got
236
259
if (
237
260
hasattr (var , "str_repr" )
238
261
and callable (var .str_repr )
@@ -243,6 +266,10 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
243
266
return False
244
267
245
268
269
+ def _has_owner (var : Variable ):
270
+ return hasattr (var , "owner" ) and var .owner
271
+
272
+
246
273
def _pymc_pprint (obj : Union [TensorVariable , Model ], * args , ** kwargs ):
247
274
"""Pretty-print method that instructs IPython to use our `str_repr()`.
248
275
0 commit comments