Skip to content

Commit d65db13

Browse files
committed
Make Model.str_repr robust to variables without monkey-patch
1 parent df7b267 commit d65db13

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

pymc/printing.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import itertools
16-
1715
from typing import Union
1816

1917
from pytensor.compile import SharedVariable
@@ -98,36 +96,46 @@ def str_for_dist(
9896
def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
9997
"""Make a human-readable string representation of Model, listing all random variables
10098
and their distributions, optionally including parameter values."""
101-
all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials)
10299

103-
rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv]
104-
rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr]
100+
kwargs = dict(formatting=formatting, include_params=include_params)
101+
free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs]
102+
observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs]
103+
det_reprs = [
104+
str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic")
105+
for dist in model.deterministics
106+
]
107+
potential_reprs = [
108+
str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential")
109+
for pot in model.potentials
110+
]
111+
112+
var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs
105113

106-
if not rv_reprs:
114+
if not var_reprs:
107115
return ""
108116
if "latex" in formatting:
109-
rv_reprs = [
110-
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
111-
for rv_repr in rv_reprs
112-
if rv_repr is not None
117+
var_reprs = [
118+
var_repr.replace(r"\sim", r"&\sim &").strip("$")
119+
for var_repr in var_reprs
120+
if var_repr is not None
113121
]
114122
return r"""$$
115123
\begin{{array}}{{rcl}}
116124
{}
117125
\end{{array}}
118126
$$""".format(
119-
"\\\\".join(rv_reprs)
127+
"\\\\".join(var_reprs)
120128
)
121129
else:
122130
# align vars on their ~
123-
names = [s[: s.index("~") - 1] for s in rv_reprs]
124-
distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
131+
names = [s[: s.index("~") - 1] for s in var_reprs]
132+
distrs = [s[s.index("~") + 2 :] for s in var_reprs]
125133
maxlen = str(max(len(x) for x in names))
126-
rv_reprs = [
134+
var_reprs = [
127135
("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
128136
for n, d in zip(names, distrs)
129137
]
130-
return "\n".join(rv_reprs)
138+
return "\n".join(var_reprs)
131139

132140

133141
def str_for_potential_or_deterministic(

tests/test_printing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import numpy as np
1515

16+
from pytensor.tensor.random import normal
17+
1618
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
1719
from pymc.distributions import (
1820
Dirichlet,
@@ -274,3 +276,15 @@ def test_model_latex_repr_mixture_model():
274276
"$$",
275277
]
276278
assert [line.strip() for line in latex_repr.split("\n")] == expected
279+
280+
281+
def test_model_repr_variables_without_monkey_patched_repr():
282+
"""Test that model repr does not rely on individual variables having the str_repr method monkey patched."""
283+
x = normal(name="x")
284+
assert not hasattr(x, "str_repr")
285+
286+
model = Model()
287+
model.register_rv(x, "x")
288+
289+
str_repr = model.str_repr()
290+
assert str_repr == "x ~ Normal(0, 1)"

0 commit comments

Comments
 (0)