Skip to content

Set theano config in model context #2103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from theano import theano, tensor as tt
from theano.tensor.var import TensorVariable

from pymc3.theanof import set_theano_conf
import pymc3 as pm
from pymc3.math import flatten_list
from .memoize import memoize
Expand Down Expand Up @@ -108,10 +109,16 @@ class Context(object):

def __enter__(self):
type(self).get_contexts().append(self)
# self._theano_config is set in Model.__new__
if hasattr(self, '_theano_config'):
self._old_theano_config = set_theano_conf(self._theano_config)
return self

def __exit__(self, typ, value, traceback):
type(self).get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, '_old_theano_config'):
set_theano_conf(self._old_theano_config)

@classmethod
def get_contexts(cls):
Expand Down Expand Up @@ -301,6 +308,11 @@ class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
will be passed to the parent instance. So that 'nested' model
contributes to the variables and likelihood factors of
parent model.
theano_config : dict, default=None
A dictionary of theano config values that should be set
temporarily in the model context. See the documentation
of theano for a complete list. Set `compute_test_value` to
`raise` if it is None.

Examples
--------
Expand Down Expand Up @@ -367,9 +379,13 @@ def __new__(cls, *args, **kwargs):
instance._parent = cls.get_contexts()[-1]
else:
instance._parent = None
theano_config = kwargs.get('theano_config', None)
if theano_config is None or 'compute_test_value' not in theano_config:
theano_config = {'compute_test_value': 'raise'}
instance._theano_config = theano_config
return instance

def __init__(self, name='', model=None):
def __init__(self, name='', model=None, theano_config=None):
self.name = name
if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down Expand Up @@ -1032,7 +1048,3 @@ def all_continuous(vars):
return False
else:
return True

# theano stuff
theano.config.warn.sum_div_dimshuffle_bug = False
theano.config.compute_test_value = 'raise'
9 changes: 9 additions & 0 deletions pymc3/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import theano
import pytest


@pytest.fixture(scope="session", autouse=True)
def theano_config():
config = theano.configparser.change_flags(compute_test_value='raise')
with config:
yield
25 changes: 23 additions & 2 deletions pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def gen2():
yield np.ones((20, 100)) * i
i += 1


class NewModel(pm.Model):
def __init__(self, name='', model=None):
super(NewModel, self).__init__(name, model)
Expand Down Expand Up @@ -133,13 +134,15 @@ def test_model_root(self):
with pm.Model() as sub:
assert model is sub.root


class TestObserved(object):
def test_observed_rv_fail(self):
with pytest.raises(TypeError):
with pm.Model() as model:
with pm.Model():
x = Normal('x')
Normal('n', observed=x)


class TestScaling(object):
def test_density_scaling(self):
with pm.Model() as model1:
Expand Down Expand Up @@ -171,7 +174,8 @@ def true_dens():

for i in range(10):
_1, _2, _t = p1(), p2(), next(t)
np.testing.assert_almost_equal(_1, _t, decimal=select_by_precision(float64=7, float32=2)) # Value O(-50,000)
decimals = select_by_precision(float64=7, float32=2)
np.testing.assert_almost_equal(_1, _t, decimal=decimals) # Value O(-50,000)
np.testing.assert_almost_equal(_1, _2)
# Done

Expand All @@ -192,3 +196,20 @@ def test_gradient_with_scaling(self):
g1 = grad1(1)
g2 = grad2(1)
np.testing.assert_almost_equal(g1, g2)


class TestTheanoConfig(object):
def test_set_testval_raise(self):
with theano.configparser.change_flags(compute_test_value='off'):
with pm.Model():
assert theano.config.compute_test_value == 'raise'
assert theano.config.compute_test_value == 'off'

def test_nested(self):
with theano.configparser.change_flags(compute_test_value='off'):
with pm.Model(theano_config={'compute_test_value': 'ignore'}):
assert theano.config.compute_test_value == 'ignore'
with pm.Model(theano_config={'compute_test_value': 'warn'}):
assert theano.config.compute_test_value == 'warn'
assert theano.config.compute_test_value == 'ignore'
assert theano.config.compute_test_value == 'off'
6 changes: 3 additions & 3 deletions pymc3/tests/test_special_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def test_functions():
psi = function([x], ps.psi(x))
function([x, p], ps.multigammaln(x, p))
for x in xvals:
yield check_vals, gammaln, ss.gammaln, x
check_vals(gammaln, ss.gammaln, x)
for x in xvals[1:]:
yield check_vals, psi, ss.psi, x
check_vals(psi, ss.psi, x)

"""
scipy.special.multigammaln gives bad values if you pass a non scalar to a
Expand Down Expand Up @@ -52,7 +52,7 @@ def ssmultigammaln(a, b):

for p in [0, 1, 2, 3, 4, 100]:
for x in xvals:
yield check_vals, multigammaln, ssmultigammaln, x, p
check_vals(multigammaln, ssmultigammaln, x, p)


def check_vals(fn1, fn2, *args):
Expand Down
10 changes: 5 additions & 5 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def teardown_class(self):

def test_sample_exact(self):
for step_method in self.master_samples:
yield self.check_trace, step_method
self.check_trace(step_method)

def check_trace(self, step_method):
"""Tests whether the trace for step methods is exactly the same as on master.
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_step_continuous(self):
)
for step in steps:
trace = sample(8000, step=step, start=start, model=model, random_seed=1)
yield self.check_stat, check, trace, step.__class__.__name__
self.check_stat(check, trace, step.__class__.__name__)

def test_step_discrete(self):
if theano.config.floatX == "float32":
Expand All @@ -211,7 +211,7 @@ def test_step_discrete(self):
)
for step in steps:
trace = sample(20000, step=step, start=start, model=model, random_seed=1)
yield self.check_stat, check, trace, step.__class__.__name__
self.check_stat(check, trace, step.__class__.__name__)

def test_step_categorical(self):
start, model, (mu, C) = simple_categorical()
Expand All @@ -225,7 +225,7 @@ def test_step_categorical(self):
)
for step in steps:
trace = sample(8000, step=step, start=start, model=model, random_seed=1)
yield self.check_stat, check, trace, step.__class__.__name__
self.check_stat(check, trace, step.__class__.__name__)

def test_step_elliptical_slice(self):
start, model, (K, L, mu, std, noise) = mv_prior_simple()
Expand All @@ -239,7 +239,7 @@ def test_step_elliptical_slice(self):
)
for step in steps:
trace = sample(5000, step=step, start=start, model=model, random_seed=1)
yield self.check_stat, check, trace, step.__class__.__name__
self.check_stat(check, trace, step.__class__.__name__)


class TestMetropolisProposal(object):
Expand Down
25 changes: 24 additions & 1 deletion pymc3/tests/test_theanof.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pickle
import itertools
import collections
import numpy as np
from theano import theano
from pymc3.theanof import GeneratorOp, generator, tt_rng, floatX
from pymc3.theanof import GeneratorOp, generator, tt_rng, floatX, set_theano_conf
from pymc3.data import DataSampler, GeneratorAdapter
import pytest

Expand Down Expand Up @@ -99,3 +100,25 @@ def test_gen_cloning_with_shape_change(self):
shared = theano.shared(data)
res2 = theano.clone(res, {gen: shared**2})
assert res2.eval().shape == (1000,)


class TestSetTheanoConfig(object):
def test_invalid_key(self):
with pytest.raises(ValueError) as e:
set_theano_conf({'bad_key': True})
e.match('Unknown')

def test_restore_when_bad_key(self):
with theano.configparser.change_flags(compute_test_value='off'):
with pytest.raises(ValueError):
conf = collections.OrderedDict(
[('compute_test_value', 'raise'), ('bad_key', True)])
set_theano_conf(conf)
assert theano.config.compute_test_value == 'off'

def test_restore(self):
with theano.configparser.change_flags(compute_test_value='off'):
conf = set_theano_conf({'compute_test_value': 'raise'})
assert conf == {'compute_test_value': 'off'}
conf = set_theano_conf(conf)
assert conf == {'compute_test_value': 'raise'}
28 changes: 28 additions & 0 deletions pymc3/theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,31 @@ def grad(self, args, g_outs):

def floatX_array(x):
return floatX(np.array(x))


def set_theano_conf(values):
"""Change the theano configuration and return old values.

This is similar to `theano.configparser.change_flags`, but it
returns the original values in a pickleable form.
"""
variables = {}
unknown = set(values.keys())
for variable in theano.configparser._config_var_list:
if variable.fullname in values:
variables[variable.fullname] = variable
unknown.remove(variable.fullname)
if len(unknown) > 0:
raise ValueError("Unknown theano config settings: %s" % unknown)

old = {}
for name, variable in variables.items():
old_value = variable.__get__(True, None)
try:
variable.__set__(None, values[name])
except Exception:
for key, old_value in old.items():
variables[key].__set__(None, old_value)
raise
old[name] = old_value
return old