Skip to content

Commit 049c5f8

Browse files
Prevent Model from turning on test value computations
1 parent 80d189f commit 049c5f8

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

pymc3/model.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -809,10 +809,7 @@ def __new__(cls, *args, **kwargs):
809809
instance._parent = kwargs.get("model")
810810
else:
811811
instance._parent = cls.get_context(error_if_none=False)
812-
aesara_config = kwargs.get("aesara_config", None)
813-
if aesara_config is None or "compute_test_value" not in aesara_config:
814-
aesara_config = {"compute_test_value": "ignore"}
815-
instance._aesara_config = aesara_config
812+
instance._aesara_config = kwargs.get("aesara_config", {})
816813
return instance
817814

818815
def __init__(self, name="", model=None, aesara_config=None, coords=None, check_bounds=True):
@@ -1007,7 +1004,20 @@ def independent_vars(self):
10071004
@property
10081005
def test_point(self):
10091006
"""Test point used to check that the model doesn't generate errors"""
1010-
return Point(((var, var.tag.test_value) for var in self.vars), model=self)
1007+
points = []
1008+
for var in self.free_RVs:
1009+
var_value = getattr(var.tag, "test_value", None)
1010+
1011+
if var_value is None:
1012+
try:
1013+
var_value = var.eval()
1014+
var.tag.test_value = var_value
1015+
except Exception:
1016+
raise Exception(f"Couldn't generate an initial value for {var}")
1017+
1018+
points.append((getattr(var.tag, "value_var", var), var_value))
1019+
1020+
return Point(points, model=self)
10111021

10121022
@property
10131023
def disc_vars(self):
@@ -1594,11 +1604,11 @@ def make_obs_var(rv_var: TensorVariable, data: Union[np.ndarray]) -> TensorVaria
15941604
else:
15951605
new_size = data.shape
15961606

1597-
test_value = getattr(rv_var.tag, "test_value", None)
1598-
15991607
rv_var = change_rv_size(rv_var, new_size)
16001608

16011609
if aesara.config.compute_test_value != "off":
1610+
test_value = getattr(rv_var.tag, "test_value", None)
1611+
16021612
if test_value is not None:
16031613
# We try to reuse the old test value
16041614
rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.tag.test_value.shape)

0 commit comments

Comments
 (0)