Skip to content

Commit 0a729b9

Browse files
committed
Refactor TestStrAndLatexRepr to allow reusing test functionality with separate models
1 parent b0d1066 commit 0a729b9

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

pymc/tests/test_printing.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,33 @@
1717
from pymc.model import Deterministic, Model, Potential
1818

1919

20-
# TODO: This test is a bit too monolithic
21-
class TestStrAndLatexRepr:
20+
class BaseTestStrAndLatexRepr:
21+
def test__repr_latex_(self):
22+
for distribution, tex in zip(self.distributions, self.expected[("latex", True)]):
23+
assert distribution._repr_latex_() == tex
24+
25+
model_tex = self.model._repr_latex_()
26+
27+
# make sure each variable is in the model
28+
for tex in self.expected[("latex", True)]:
29+
for segment in tex.strip("$").split(r"\sim"):
30+
assert segment in model_tex
31+
32+
def test_str_repr(self):
33+
for str_format in self.formats:
34+
for dist, text in zip(self.distributions, self.expected[str_format]):
35+
assert dist.str_repr(*str_format) == text
36+
37+
model_text = self.model.str_repr(*str_format)
38+
for text in self.expected[str_format]:
39+
if str_format[0] == "latex":
40+
for segment in text.strip("$").split(r"\sim"):
41+
assert segment in model_text
42+
else:
43+
assert text in model_text
44+
45+
46+
class TestMonolith(BaseTestStrAndLatexRepr):
2247
def setup_class(self):
2348
# True parameter values
2449
alpha, sigma = 1, 1
@@ -90,7 +115,7 @@ def setup_class(self):
90115

91116
self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, w, nested_mix, Y_obs, pot]
92117
self.deterministics_or_potentials = [mu, pot]
93-
# tuples of (formatting, include_params
118+
# tuples of (formatting, include_params)
94119
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
95120
self.expected = {
96121
("plain", True): [
@@ -155,30 +180,6 @@ def setup_class(self):
155180
],
156181
}
157182

158-
def test__repr_latex_(self):
159-
for distribution, tex in zip(self.distributions, self.expected[("latex", True)]):
160-
assert distribution._repr_latex_() == tex
161-
162-
model_tex = self.model._repr_latex_()
163-
164-
# make sure each variable is in the model
165-
for tex in self.expected[("latex", True)]:
166-
for segment in tex.strip("$").split(r"\sim"):
167-
assert segment in model_tex
168-
169-
def test_str_repr(self):
170-
for str_format in self.formats:
171-
for dist, text in zip(self.distributions, self.expected[str_format]):
172-
assert dist.str_repr(*str_format) == text
173-
174-
model_text = self.model.str_repr(*str_format)
175-
for text in self.expected[str_format]:
176-
if str_format[0] == "latex":
177-
for segment in text.strip("$").split(r"\sim"):
178-
assert segment in model_text
179-
else:
180-
assert text in model_text
181-
182183

183184
def test_model_latex_repr_three_levels_model():
184185
with Model() as censored_model:

0 commit comments

Comments
 (0)