Skip to content

Commit af5ea5c

Browse files
authored
implemented fix for escaping underscores in latex repr and added a un… (#7501)
* implemented fix for escaping underscores in latex repr and added a unit test * updated unit test staticmethod to include underscore in var name * add underscore escape fix to distribution repr as well as model repr, fixed testing to expect underscores in LaTeX representation to be escaped * added cleaner method using re to escape underscores, added cleaner test to assert underscores are escaped
1 parent 0cc291a commit af5ea5c

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

pymc/printing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
import re
17+
1618
from functools import partial
1719

1820
from pytensor.compile import SharedVariable
@@ -58,6 +60,7 @@ def str_for_dist(
5860
if "latex" in formatting:
5961
if print_name is not None:
6062
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
63+
print_name = _format_underscore(print_name)
6164

6265
op_name = (
6366
dist.owner.op._print_name[1]
@@ -114,6 +117,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
114117
if not var_reprs:
115118
return ""
116119
if "latex" in formatting:
120+
var_reprs = [_format_underscore(x) for x in var_reprs]
117121
var_reprs = [
118122
var_repr.replace(r"\sim", r"&\sim &").strip("$")
119123
for var_repr in var_reprs
@@ -295,3 +299,10 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
295299
except (ModuleNotFoundError, AttributeError):
296300
# no ipython shell
297301
pass
302+
303+
304+
def _format_underscore(variable: str) -> str:
305+
"""
306+
Escapes all unescaped underscores in the variable name for LaTeX representation.
307+
"""
308+
return re.sub(r"(?<!\\)_", r"\\_", variable)

tests/test_printing.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import numpy as np
1516

1617
from pytensor.tensor.random import normal
@@ -171,15 +172,15 @@ def setup_class(self):
171172
r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$",
172173
r"$\text{beta} \sim \operatorname{Normal}(0,~10)$",
173174
r"$\text{Z} \sim \operatorname{MultivariateNormal}(f(),~f())$",
174-
r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$",
175+
r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$",
175176
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))$",
176177
r"$\text{w} \sim \operatorname{Dirichlet}(\text{<constant>})$",
177178
(
178-
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
179+
r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}(\text{w},"
179180
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5)),"
180181
r"~\operatorname{Censored}(\operatorname{Bernoulli}(0.5),~-1,~1))$"
181182
),
182-
r"$\text{Y_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$",
183+
r"$\text{Y\_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$",
183184
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
184185
r"$\text{pred} \sim \operatorname{Deterministic}(f(\text{<normal>}))",
185186
],
@@ -189,11 +190,11 @@ def setup_class(self):
189190
r"$\text{mu} \sim \operatorname{Deterministic}$",
190191
r"$\text{beta} \sim \operatorname{Normal}$",
191192
r"$\text{Z} \sim \operatorname{MultivariateNormal}$",
192-
r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}$",
193+
r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}$",
193194
r"$\text{zip} \sim \operatorname{MarginalMixture}$",
194195
r"$\text{w} \sim \operatorname{Dirichlet}$",
195-
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$",
196-
r"$\text{Y_obs} \sim \operatorname{Normal}$",
196+
r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}$",
197+
r"$\text{Y\_obs} \sim \operatorname{Normal}$",
197198
r"$\text{pot} \sim \operatorname{Potential}$",
198199
r"$\text{pred} \sim \operatorname{Deterministic}",
199200
],
@@ -256,7 +257,7 @@ def test_model_latex_repr_three_levels_model():
256257
"$$",
257258
"\\begin{array}{rcl}",
258259
"\\text{mu} &\\sim & \\operatorname{Normal}(0,~5)\\\\\\text{sigma} &\\sim & "
259-
"\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
260+
"\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored\\_normal} &\\sim & "
260261
"\\operatorname{Censored}(\\operatorname{Normal}(\\text{mu},~\\text{sigma}),~-2,~2)",
261262
"\\end{array}",
262263
"$$",
@@ -316,3 +317,22 @@ def random(rng, mu, size):
316317

317318
str_repr = model.str_repr(include_params=False)
318319
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])
320+
321+
322+
class TestLatexRepr:
323+
@staticmethod
324+
def simple_model() -> Model:
325+
with Model() as simple_model:
326+
error = HalfNormal("error", 0.5)
327+
alpha_a = Normal("alpha_a", 0, 1)
328+
Normal("y", alpha_a, error)
329+
return simple_model
330+
331+
def test_latex_escaped_underscore(self):
332+
"""
333+
Ensures that all underscores in model variable names are properly escaped for LaTeX representation
334+
"""
335+
model = self.simple_model()
336+
model_str = model.str_repr(formatting="latex")
337+
assert "\\_" in model_str
338+
assert "_" not in model_str.replace("\\_", "")

0 commit comments

Comments
 (0)