Skip to content

Commit f9057c4

Browse files
committed
Add more tests and clean up a bit
1 parent 8370d97 commit f9057c4

File tree

4 files changed

+205
-48
lines changed

4 files changed

+205
-48
lines changed

pymc/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class SymbolicRandomVariable(OpFromGraph):
197197
"""
198198

199199
_print_name: Tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
200-
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""
200+
"Tuple of (name, latex name) used for pretty-printing variables of this type"
201201

202202
def __init__(self, *args, ndim_supp, **kwargs):
203203
self.ndim_supp = ndim_supp

pymc/distributions/multivariate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ class _LKJCholeskyCovBaseRV(RandomVariable):
11161116
ndim_supp = 1
11171117
ndims_params = [0, 0, 1]
11181118
dtype = "floatX"
1119-
_print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")
1119+
_print_name = ("_lkjcholeskycovbase", r"\operatorname{\_lkjcholeskycovbase}")
11201120

11211121
def make_node(self, rng, size, dtype, n, eta, D):
11221122
n = at.as_tensor_variable(n)
@@ -1164,7 +1164,7 @@ def rng_fn(self, rng, n, eta, D, size):
11641164
# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
11651165
class _LKJCholeskyCovRV(SymbolicRandomVariable):
11661166
default_output = 1
1167-
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
1167+
_print_name = ("_lkjcholeskycov", r"\operatorname{\_lkjcholeskycov}")
11681168

11691169
def update(self, node):
11701170
return {node.inputs[0]: node.outputs[0]}

pymc/printing.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def str_for_model_var(
4141
4242
Intended for Distribution, Deterministic, and Potential.
4343
"""
44+
if not (
45+
_has_owner(var) and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable))
46+
) and not _is_potential_or_deterministic(var):
47+
raise ValueError(
48+
f"Variable for pretty-printing must be a model variable or the output of .dist(). Received unsupported variable {var}"
49+
)
4450
var_name, dist_name, args_str = _get_varname_distname_args(
4551
var, formatting=formatting, dist_name=dist_name
4652
)
@@ -56,6 +62,9 @@ def str_for_model_var(
5662
if formatting == "latex":
5763
out = rf"${var_name} \sim {dist_name}({args_str})$"
5864
elif formatting == "plain":
65+
var_name = var_name.replace("~", "-")
66+
dist_name = dist_name.replace("~", "-")
67+
args_str = args_str.replace("~", "-")
5968
out = f"{var_name} ~ {dist_name}({args_str})"
6069
else:
6170
raise ValueError(
@@ -72,7 +81,7 @@ def str_for_model(model: Model, formatting: str = "plain", **kwargs) -> str:
7281
rv_reprs = [rv.str_repr(formatting=formatting, **kwargs) for rv in all_rv]
7382
if not rv_reprs:
7483
return ""
75-
if "latex" in formatting:
84+
if formatting == "latex":
7685
rv_reprs = [rv_repr.replace(r"\sim", r"&\sim &").strip("$") for rv_repr in rv_reprs]
7786
return r"""$$
7887
\begin{{array}}{{rcl}}
@@ -98,36 +107,45 @@ def _get_varname_distname_args(
98107
) -> Tuple[str, str, str]:
99108
"""Generate formatted strings for the name, distribution name, and
100109
arguments list of a Model variable.
110+
111+
For Distribution, Potential, Deterministic, or .dist().
101112
"""
102113
# Name and distribution name
103-
name = var.name if var.name is not None else "<unnamed>"
104-
if not dist_name and hasattr(var.owner.op, "_print_name"):
114+
name = var.name if var.name is not None else "<unnamed>" # May be missing if from a dist()
115+
if (
116+
not dist_name
117+
and _has_owner(var)
118+
and hasattr(var.owner.op, "_print_name")
119+
and var.owner.op._print_name
120+
):
105121
# The _print_name tuple is necessary for maximum prettiness because a few RVs
106122
# use special formatting (e.g. superscripts) for their latex print name
107123
dist_name = (
108124
var.owner.op._print_name[1] if formatting == "latex" else var.owner.op._print_name[0]
109125
)
110126
elif not dist_name:
111-
dist_name = "Unknown"
127+
raise ValueError(
128+
f"Missing distribution name for model variable: {var}. Provide one via the"
129+
" _print_name attribute of your RandomVariable."
130+
)
112131
if formatting == "latex":
113132
name = _latex_clean_command(name, command="text")
114133
dist_name = _latex_clean_command(dist_name, command="operatorname")
134+
115135
# Arguments passed to the distribution or expression
116-
if isinstance(var.owner.op, RandomVariable):
117-
# var is the RV from a Distribution.
136+
if _has_owner(var) and isinstance(var.owner.op, RandomVariable):
137+
# var is the RV or dist() from a Distribution.
118138
dist_args = var.owner.inputs[3:] # First 3 inputs are always rng, size, dtype
119-
elif isinstance(var.owner.op, SymbolicRandomVariable):
139+
elif _has_owner(var) and isinstance(var.owner.op, SymbolicRandomVariable):
120140
# var is a symbolic RV from a Distribution.
121141
dist_args = [
122142
x
123143
for x in var.owner.inputs
124144
if not isinstance(x, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
125145
]
126-
elif _is_potential_or_deterministic(var):
127-
# var is a Deterministic or a Potential.
128-
dist_args = _walk_expression_args(var)
129146
else:
130-
raise ValueError(f"Unable to parse arguments for variable")
147+
# Assume that var is a Deterministic or a Potential.
148+
dist_args = _walk_expression_args(var)
131149
args_str = _str_for_args_list(dist_args, formatting=formatting)
132150
if _is_potential_or_deterministic(var):
133151
args_str = f"f({args_str})" # TODO do we still want to do this?
@@ -153,32 +171,37 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
153171
if var_data.size == 1:
154172
return f"{var_data.flatten()[0]:.3g}"
155173
else:
156-
return f"<{var_type} {var_data.shape}>" # TODO shape info or nah?
157-
elif isinstance(var.owner.op, DimShuffle):
158-
# Recurse
159-
return _str_for_input_var(var.owner.inputs[0], formatting=formatting)
160-
elif _is_potential_or_deterministic(var) or isinstance(
161-
var.owner.op, (RandomVariable, SymbolicRandomVariable)
162-
):
163-
if var.name:
174+
return f"<{var_type} {var_data.shape}>"
175+
elif _has_owner(var):
176+
if isinstance(var.owner.op, DimShuffle):
177+
# Recurse
178+
return _str_for_input_var(var.owner.inputs[0], formatting=formatting)
179+
elif _is_potential_or_deterministic(var) or isinstance(
180+
var.owner.op, (RandomVariable, SymbolicRandomVariable)
181+
):
164182
# Give the name of the RV/Potential/Deterministic if available
165-
return var.name
166-
else:
183+
if var.name:
184+
return var.name
167185
# But if rv comes from .dist() we print the distribution with its args
168-
_, dist_name, args_str = _get_varname_distname_args(var, formatting=formatting)
169-
return f"{dist_name}({args_str})"
170-
elif hasattr(var, "owner") and var.owner:
171-
# Return an "expression" i.e. indicate that this variable is a function of other
172-
# variables. Looks like f(arg1, ..., argN). Previously _str_for_expression()
173-
args = _walk_expression_args(var)
174-
args_str = _str_for_args_list(args, formatting=formatting)
175-
return f"f({args_str})"
186+
else:
187+
_, dist_name, args_str = _get_varname_distname_args(var, formatting=formatting)
188+
return f"{dist_name}({args_str})"
189+
else:
190+
# Return an "expression" i.e. indicate that this variable is a function of other
191+
# variables. Looks like f(arg1, ..., argN). Previously _str_for_expression()
192+
args = _walk_expression_args(var)
193+
args_str = _str_for_args_list(args, formatting=formatting)
194+
return f"f({args_str})"
176195
else:
177-
raise ValueError("Unidentified variable in dist or expression args")
196+
raise ValueError(
197+
f"Unidentified variable in dist or expression args: {var}. If you think this is a bug, please create an issue in the project Github."
198+
)
178199

179200

180201
def _walk_expression_args(var: Variable) -> List[Variable]:
181202
"""Find all arguments of an expression"""
203+
if not var.owner:
204+
return []
182205

183206
def _expand(x):
184207
if x.owner and (not isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))):
@@ -210,7 +233,10 @@ def _str_for_args_list(args: List[Variable], formatting: str) -> str:
210233

211234
def _latex_clean_command(text: str, command: str) -> str:
212235
r"""Prepare text for LaTeX and maybe wrap it in a \command{}."""
213-
text = text.replace("$", r"\$") # TODO do we want to keep dollar signs or strip them?
236+
text = text.replace("$", r"\$")
237+
# str_for_model() uses \sim to format the array, and rendering a literal
237+
# tilde properly in LaTeX is hard, so we replace it with a dash for simplicity
239+
text = text.replace("~", "-")
214240
if not text.startswith(rf"\{command}"):
215241
# The printing module is designed such that text never passes through this
216242
# function more than once. However, in some cases the text may have already
@@ -225,14 +251,11 @@ def _latex_clean_command(text: str, command: str) -> str:
225251
# command itself, writing the character, then continuing on with the same command.
226252
if command == "text":
227253
text = text.replace("_", rf"}}\_\{command}{{")
228-
text = text.replace("~", rf"}}~\{command}{{")
229254
return text
230255

231256

232257
def _is_potential_or_deterministic(var: Variable) -> bool:
233-
# This is a bit hacky but seems like the best we got. We should write
234-
# a test to make sure that Deterministic and Potential don't get updated
235-
# without also modifying this function.
258+
# This is a bit hacky but seems like the best we got
236259
if (
237260
hasattr(var, "str_repr")
238261
and callable(var.str_repr)
@@ -243,6 +266,10 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
243266
return False
244267

245268

269+
def _has_owner(var: Variable):
270+
return hasattr(var, "owner") and var.owner
271+
272+
246273
def _pymc_pprint(obj: Union[TensorVariable, Model], *args, **kwargs):
247274
"""Pretty-print method that instructs IPython to use our `str_repr()`.
248275

0 commit comments

Comments
 (0)