|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -import itertools |
16 |
| - |
17 | 15 | from typing import Union
|
18 | 16 |
|
19 | 17 | from pytensor.compile import SharedVariable
|
@@ -98,36 +96,46 @@ def str_for_dist(
|
98 | 96 | def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
|
99 | 97 | """Make a human-readable string representation of Model, listing all random variables
|
100 | 98 | and their distributions, optionally including parameter values."""
|
101 |
| - all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials) |
102 | 99 |
|
103 |
| - rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv] |
104 |
| - rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr] |
| 100 | + kwargs = dict(formatting=formatting, include_params=include_params) |
| 101 | + free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs] |
| 102 | + observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs] |
| 103 | + det_reprs = [ |
| 104 | + str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic") |
| 105 | + for dist in model.deterministics |
| 106 | + ] |
| 107 | + potential_reprs = [ |
| 108 | + str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential") |
| 109 | + for pot in model.potentials |
| 110 | + ] |
| 111 | + |
| 112 | + var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs |
105 | 113 |
|
106 |
| - if not rv_reprs: |
| 114 | + if not var_reprs: |
107 | 115 | return ""
|
108 | 116 | if "latex" in formatting:
|
109 |
| - rv_reprs = [ |
110 |
| - rv_repr.replace(r"\sim", r"&\sim &").strip("$") |
111 |
| - for rv_repr in rv_reprs |
112 |
| - if rv_repr is not None |
| 117 | + var_reprs = [ |
| 118 | + var_repr.replace(r"\sim", r"&\sim &").strip("$") |
| 119 | + for var_repr in var_reprs |
| 120 | + if var_repr is not None |
113 | 121 | ]
|
114 | 122 | return r"""$$
|
115 | 123 | \begin{{array}}{{rcl}}
|
116 | 124 | {}
|
117 | 125 | \end{{array}}
|
118 | 126 | $$""".format(
|
119 |
| - "\\\\".join(rv_reprs) |
| 127 | + "\\\\".join(var_reprs) |
120 | 128 | )
|
121 | 129 | else:
|
122 | 130 | # align vars on their ~
|
123 |
| - names = [s[: s.index("~") - 1] for s in rv_reprs] |
124 |
| - distrs = [s[s.index("~") + 2 :] for s in rv_reprs] |
| 131 | + names = [s[: s.index("~") - 1] for s in var_reprs] |
| 132 | + distrs = [s[s.index("~") + 2 :] for s in var_reprs] |
125 | 133 | maxlen = str(max(len(x) for x in names))
|
126 |
| - rv_reprs = [ |
| 134 | + var_reprs = [ |
127 | 135 | ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
|
128 | 136 | for n, d in zip(names, distrs)
|
129 | 137 | ]
|
130 |
| - return "\n".join(rv_reprs) |
| 138 | + return "\n".join(var_reprs) |
131 | 139 |
|
132 | 140 |
|
133 | 141 | def str_for_potential_or_deterministic(
|
|
0 commit comments