Skip to content

Commit fd5a6cc

Browse files
committed
Fix latex representation for SharedVariable inputs
1 parent 0a729b9 commit fd5a6cc

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

pymc/printing.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
from typing import Union
1818

19-
from aesara.graph.basic import walk
19+
from aesara.compile import SharedVariable
20+
from aesara.graph.basic import Constant, walk
2021
from aesara.tensor.basic import TensorVariable, Variable
2122
from aesara.tensor.elemwise import DimShuffle
2223
from aesara.tensor.random.basic import RandomVariable
2324
from aesara.tensor.random.var import (
2425
RandomGeneratorSharedVariable,
2526
RandomStateSharedVariable,
2627
)
27-
from aesara.tensor.var import TensorConstant
2828

2929
from pymc.model import Model
3030

@@ -163,7 +163,7 @@ def _is_potential_or_determinstic(var: Variable) -> bool:
163163
# in case other code overrides str_repr, fallback
164164
return False
165165

166-
if isinstance(var, TensorConstant):
166+
if isinstance(var, (Constant, SharedVariable)):
167167
return _str_for_constant(var, formatting)
168168
elif isinstance(
169169
var.owner.op, (RandomVariable, SymbolicRandomVariable)
@@ -189,15 +189,22 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
189189
return _str
190190

191191

192-
def _str_for_constant(var: TensorConstant, formatting: str) -> str:
193-
if len(var.data.shape) == 0:
194-
return f"{var.data:.3g}"
195-
elif len(var.data.shape) == 1 and var.data.shape[0] == 1:
196-
return f"{var.data[0]:.3g}"
192+
def _str_for_constant(var: Union[Constant, SharedVariable], formatting: str) -> str:
193+
if isinstance(var, Constant):
194+
var_data = var.data
195+
var_type = "constant"
196+
else:
197+
var_data = var.get_value()
198+
var_type = "shared"
199+
200+
if len(var_data.shape) == 0:
201+
return f"{var_data:.3g}"
202+
elif len(var_data.shape) == 1 and var_data.shape[0] == 1:
203+
return f"{var_data[0]:.3g}"
197204
elif "latex" in formatting:
198-
return r"\text{<constant>}"
205+
return rf"\text{{<{var_type}>}}"
199206
else:
200-
return r"<constant>"
207+
return rf"<{var_type}>"
201208

202209

203210
def _str_for_expression(var: Variable, formatting: str) -> str:

pymc/tests/test_printing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,48 @@ def setup_class(self):
181181
}
182182

183183

184+
class TestData(BaseTestStrAndLatexRepr):
185+
def setup_class(self):
186+
with Model() as self.model:
187+
import pymc as pm
188+
189+
with pm.Model() as model:
190+
a = pm.Normal("a", pm.MutableData("a_data", (2,)))
191+
b = pm.Normal("b", pm.MutableData("b_data", (2, 3)))
192+
c = pm.Normal("c", pm.ConstantData("c_data", (2,)))
193+
d = pm.Normal("d", pm.ConstantData("d_data", (2, 3)))
194+
195+
self.distributions = [a, b, c, d]
196+
# tuples of (formatting, include_params)
197+
self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)]
198+
self.expected = {
199+
("plain", True): [
200+
r"a ~ N(2, 1)",
201+
r"b ~ N(<shared>, 1)",
202+
r"c ~ N(2, 1)",
203+
r"d ~ N(<constant>, 1)",
204+
],
205+
("plain", False): [
206+
r"a ~ N",
207+
r"b ~ N",
208+
r"c ~ N",
209+
r"d ~ N",
210+
],
211+
("latex", True): [
212+
r"$\text{a} \sim \operatorname{N}(2,~1)$",
213+
r"$\text{b} \sim \operatorname{N}(\text{<shared>},~1)$",
214+
r"$\text{c} \sim \operatorname{N}(2,~1)$",
215+
r"$\text{d} \sim \operatorname{N}(\text{<constant>},~1)$",
216+
],
217+
("latex", False): [
218+
r"$\text{a} \sim \operatorname{N}$",
219+
r"$\text{b} \sim \operatorname{N}$",
220+
r"$\text{c} \sim \operatorname{N}$",
221+
r"$\text{d} \sim \operatorname{N}$",
222+
],
223+
}
224+
225+
184226
def test_model_latex_repr_three_levels_model():
185227
with Model() as censored_model:
186228
mu = Normal("mu", 0.0, 5.0)

0 commit comments

Comments
 (0)