Skip to content

Commit 6d6d6bd

Browse files
committed
Set theano config in model context
1 parent 56a5f1a commit 6d6d6bd

File tree

5 files changed

+103
-8
lines changed

5 files changed

+103
-8
lines changed

pymc3/model.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from theano import theano, tensor as tt
99
from theano.tensor.var import TensorVariable
1010

11+
from pymc3.theanof import set_theano_conf
1112
import pymc3 as pm
1213
from pymc3.math import flatten_list
1314
from .memoize import memoize
@@ -108,10 +109,16 @@ class Context(object):
108109

109110
def __enter__(self):
110111
type(self).get_contexts().append(self)
112+
# self._theano_config is set in Model.__new__
113+
if hasattr(self, '_theano_config'):
114+
self._old_theano_config = set_theano_conf(self._theano_config)
111115
return self
112116

113117
def __exit__(self, typ, value, traceback):
114118
type(self).get_contexts().pop()
119+
# self._theano_config is set in Model.__new__
120+
if hasattr(self, '_old_theano_config'):
121+
set_theano_conf(self._old_theano_config)
115122

116123
@classmethod
117124
def get_contexts(cls):
@@ -301,6 +308,11 @@ class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
301308
will be passed to the parent instance. So that 'nested' model
302309
contributes to the variables and likelihood factors of
303310
parent model.
311+
theano_config : dict, default=None
312+
A dictionary of theano config values that should be set
313+
temporarily in the model context. See the documentation
314+
of theano for a complete list. Set `compute_test_value` to
315+
`raise` if it is None.
304316
305317
Examples
306318
--------
@@ -367,9 +379,13 @@ def __new__(cls, *args, **kwargs):
367379
instance._parent = cls.get_contexts()[-1]
368380
else:
369381
instance._parent = None
382+
theano_config = kwargs.get('theano_config', None)
383+
if theano_config is None or 'compute_test_value' not in theano_config:
384+
theano_config = {'compute_test_value': 'raise'}
385+
instance._theano_config = theano_config
370386
return instance
371387

372-
def __init__(self, name='', model=None):
388+
def __init__(self, name='', model=None, theano_config=None):
373389
self.name = name
374390
if self.parent is not None:
375391
self.named_vars = treedict(parent=self.parent.named_vars)
@@ -1032,7 +1048,3 @@ def all_continuous(vars):
10321048
return False
10331049
else:
10341050
return True
1035-
1036-
# theano stuff
1037-
theano.config.warn.sum_div_dimshuffle_bug = False
1038-
theano.config.compute_test_value = 'raise'

pymc3/tests/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import theano
2+
3+
config = theano.configparser.change_flags(compute_test_value='raise')
4+
5+
6+
def pytest_sessionstart(session):
7+
config.__enter__()
8+
9+
10+
def pytest_sessionfinish(session, exitstatus):
11+
config.__exit__()

pymc3/tests/test_model.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def gen2():
2222
yield np.ones((20, 100)) * i
2323
i += 1
2424

25+
2526
class NewModel(pm.Model):
2627
def __init__(self, name='', model=None):
2728
super(NewModel, self).__init__(name, model)
@@ -133,13 +134,15 @@ def test_model_root(self):
133134
with pm.Model() as sub:
134135
assert model is sub.root
135136

137+
136138
class TestObserved(object):
137139
def test_observed_rv_fail(self):
138140
with pytest.raises(TypeError):
139-
with pm.Model() as model:
141+
with pm.Model():
140142
x = Normal('x')
141143
Normal('n', observed=x)
142144

145+
143146
class TestScaling(object):
144147
def test_density_scaling(self):
145148
with pm.Model() as model1:
@@ -171,7 +174,8 @@ def true_dens():
171174

172175
for i in range(10):
173176
_1, _2, _t = p1(), p2(), next(t)
174-
np.testing.assert_almost_equal(_1, _t, decimal=select_by_precision(float64=7, float32=2)) # Value O(-50,000)
177+
decimals = select_by_precision(float64=7, float32=2)
178+
np.testing.assert_almost_equal(_1, _t, decimal=decimals) # Value O(-50,000)
175179
np.testing.assert_almost_equal(_1, _2)
176180
# Done
177181

@@ -192,3 +196,20 @@ def test_gradient_with_scaling(self):
192196
g1 = grad1(1)
193197
g2 = grad2(1)
194198
np.testing.assert_almost_equal(g1, g2)
199+
200+
201+
class TestTheanoConfig(object):
202+
def test_set_testval_raise(self):
203+
with theano.configparser.change_flags(compute_test_value='off'):
204+
with pm.Model():
205+
assert theano.config.compute_test_value == 'raise'
206+
assert theano.config.compute_test_value == 'off'
207+
208+
def test_nested(self):
209+
with theano.configparser.change_flags(compute_test_value='off'):
210+
with pm.Model(theano_config={'compute_test_value': 'ignore'}):
211+
assert theano.config.compute_test_value == 'ignore'
212+
with pm.Model(theano_config={'compute_test_value': 'warn'}):
213+
assert theano.config.compute_test_value == 'warn'
214+
assert theano.config.compute_test_value == 'ignore'
215+
assert theano.config.compute_test_value == 'off'

pymc3/tests/test_theanof.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pickle
22
import itertools
3+
import collections
34
import numpy as np
45
from theano import theano
5-
from pymc3.theanof import GeneratorOp, generator, tt_rng, floatX
6+
from pymc3.theanof import GeneratorOp, generator, tt_rng, floatX, set_theano_conf
67
from pymc3.data import DataSampler, GeneratorAdapter
78
import pytest
89

@@ -99,3 +100,25 @@ def test_gen_cloning_with_shape_change(self):
99100
shared = theano.shared(data)
100101
res2 = theano.clone(res, {gen: shared**2})
101102
assert res2.eval().shape == (1000,)
103+
104+
105+
class TestSetTheanoConfig(object):
106+
def test_invalid_key(self):
107+
with pytest.raises(ValueError) as e:
108+
set_theano_conf({'bad_key': True})
109+
e.match('Unknown')
110+
111+
def test_restore_when_bad_key(self):
112+
with theano.configparser.change_flags(compute_test_value='off'):
113+
with pytest.raises(ValueError):
114+
conf = collections.OrderedDict(
115+
[('compute_test_value', 'raise'), ('bad_key', True)])
116+
set_theano_conf(conf)
117+
assert theano.config.compute_test_value == 'off'
118+
119+
def test_restore(self):
120+
with theano.configparser.change_flags(compute_test_value='off'):
121+
conf = set_theano_conf({'compute_test_value': 'raise'})
122+
assert conf == {'compute_test_value': 'off'}
123+
conf = set_theano_conf(conf)
124+
assert conf == {'compute_test_value': 'raise'}

pymc3/theanof.py

+28
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,31 @@ def grad(self, args, g_outs):
427427

428428
def floatX_array(x):
429429
return floatX(np.array(x))
430+
431+
432+
def set_theano_conf(values):
433+
"""Change the theano configuration and return old values.
434+
435+
This is similar to `theano.configparser.change_flags`, but it
436+
returns the original values in a pickleable form.
437+
"""
438+
variables = {}
439+
unknown = set(values.keys())
440+
for variable in theano.configparser._config_var_list:
441+
if variable.fullname in values:
442+
variables[variable.fullname] = variable
443+
unknown.remove(variable.fullname)
444+
if len(unknown) > 0:
445+
raise ValueError("Unknown theano config settings: %s" % unknown)
446+
447+
old = {}
448+
for name, variable in variables.items():
449+
old_value = variable.__get__(True, None)
450+
try:
451+
variable.__set__(None, values[name])
452+
except Exception:
453+
for key, old_value in old.items():
454+
variables[key].__set__(None, old_value)
455+
raise
456+
old[name] = old_value
457+
return old

0 commit comments

Comments
 (0)