Skip to content

Commit e76ccb4

Browse files
committed
Deprecate pytensor_config
1 parent af8cf19 commit e76ccb4

File tree

4 files changed

+10
-14
lines changed

4 files changed

+10
-14
lines changed

pymc/model/core.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,13 @@ def __new__(cls, *args, **kwargs):
496496
instance._parent = kwargs.get("model")
497497
else:
498498
instance._parent = cls.get_context(error_if_none=False)
499-
instance._pytensor_config = kwargs.get("pytensor_config", {})
499+
pytensor_config = kwargs.get("pytensor_config", {})
500+
if pytensor_config:
501+
warnings.warn(
502+
"pytensor_config is deprecated. Use pytensor.config or pytensor.config.change_flags context manager instead.",
503+
FutureWarning,
504+
)
505+
instance._pytensor_config = pytensor_config
500506
return instance
501507

502508
@staticmethod

tests/model/test_core.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,9 @@ def test_compile_fn():
11021102

11031103
def test_model_pytensor_config():
11041104
assert pytensor.config.mode != "JAX"
1105-
with pm.Model(pytensor_config=dict(mode="JAX")) as model:
1105+
with pytest.warns(FutureWarning, match="pytensor_config is deprecated"):
1106+
m = pm.Model(pytensor_config=dict(mode="JAX"))
1107+
with m:
11061108
assert pytensor.config.mode == "JAX"
11071109
assert pytensor.config.mode != "JAX"
11081110

tests/sampling/test_mcmc.py

-8
Original file line numberDiff line numberDiff line change
@@ -797,14 +797,6 @@ def test_step_vars_in_model(self):
797797
class TestType:
798798
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)
799799

800-
def setup_method(self):
801-
# save PyTensor config object
802-
self.pytensor_config = copy(pytensor.config)
803-
804-
def teardown_method(self):
805-
# restore pytensor config
806-
pytensor.config = self.pytensor_config
807-
808800
@pytensor.config.change_flags({"floatX": "float64", "warn_float64": "ignore"})
809801
def test_float64(self):
810802
with pm.Model() as model:

tests/variational/test_approximations.py

-4
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
8484
y_obs = np.array([1.6, 1.4])
8585
beta = len(y_obs) / float(aux_total_size)
8686

87-
# TODO: pytensor_config
88-
# with pm.Model(pytensor_config=dict(floatX='float64')):
89-
# did not not work as expected
90-
# there were some numeric problems, so float64 is forced
9187
with pytensor.config.change_flags(floatX="float64", warn_float64="ignore"):
9288
assert pytensor.config.floatX == "float64"
9389
assert pytensor.config.warn_float64 == "ignore"

0 commit comments

Comments
 (0)