diff --git a/pymc/printing.py b/pymc/printing.py index f3ce014384..102745e771 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools - from typing import Union from pytensor.compile import SharedVariable @@ -98,36 +96,46 @@ def str_for_dist( def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str: """Make a human-readable string representation of Model, listing all random variables and their distributions, optionally including parameter values.""" - all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials) - rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv] - rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr] + kwargs = dict(formatting=formatting, include_params=include_params) + free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs] + observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs] + det_reprs = [ + str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic") + for dist in model.deterministics + ] + potential_reprs = [ + str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential") + for pot in model.potentials + ] + + var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs - if not rv_reprs: + if not var_reprs: return "" if "latex" in formatting: - rv_reprs = [ - rv_repr.replace(r"\sim", r"&\sim &").strip("$") - for rv_repr in rv_reprs - if rv_repr is not None + var_reprs = [ + var_repr.replace(r"\sim", r"&\sim &").strip("$") + for var_repr in var_reprs + if var_repr is not None ] return r"""$$ \begin{{array}}{{rcl}} {} \end{{array}} $$""".format( - "\\\\".join(rv_reprs) + "\\\\".join(var_reprs) ) else: # align vars on their ~ - names = [s[: s.index("~") - 1] for s in rv_reprs] - distrs = [s[s.index("~") + 2 :] for s in rv_reprs] + names = [s[: s.index("~") - 1] for s in var_reprs] + distrs = [s[s.index("~") + 2 :] for s in var_reprs] maxlen = str(max(len(x) for x in names)) - rv_reprs = [ + var_reprs = [ ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d) for n, d in zip(names, distrs) ] - return "\n".join(rv_reprs) + return "\n".join(var_reprs) def str_for_potential_or_deterministic( diff --git a/tests/test_printing.py b/tests/test_printing.py index 64fd21eebf..6e7a8e17d9 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -13,6 +13,8 @@ # limitations under the License. import numpy as np +from pytensor.tensor.random import normal + from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT from pymc.distributions import ( Dirichlet, @@ -274,3 +276,15 @@ def test_model_latex_repr_mixture_model(): "$$", ] assert [line.strip() for line in latex_repr.split("\n")] == expected + + +def test_model_repr_variables_without_monkey_patched_repr(): + """Test that model repr does not rely on individual variables having the str_repr method monkey patched.""" + x = normal(name="x") + assert not hasattr(x, "str_repr") + + model = Model() + model.register_rv(x, "x") + + str_repr = model.str_repr() + assert str_repr == "x ~ Normal(0, 1)"