WIP: Implement opvi #1694

Merged · 60 commits · Mar 15, 2017

Commits
cbc5568
migrate useful functions from previous PR
ferrine Jan 20, 2017
d69a96e
opvi draft
ferrine Jan 20, 2017
973ca6e
made some test work
ferrine Jan 22, 2017
aca56f6
refactored approximation to support aevb (without test)
ferrine Jan 22, 2017
1df1104
refactor opvi
ferrine Jan 22, 2017
a07afed
change log_q_local computation
ferrine Jan 22, 2017
33f8c80
add full rank approximation
ferrine Jan 22, 2017
df9adfc
add more_params argument to ObjectiveFunction.updates (aevb case)
ferrine Jan 22, 2017
3780a4b
refactor density computation in full rank approximation
ferrine Jan 22, 2017
bf6c234
typo: cast dict values to list
ferrine Jan 22, 2017
f3c6be6
typo: cast dict values to list
ferrine Jan 22, 2017
906dc5a
typo: undefined T in dist_math
ferrine Jan 22, 2017
49c96e1
refactor gradient scaling as suggested in approximateinference.org/ac…
ferrine Jan 22, 2017
85151ee
implement Langevin-Stein (LS) operator
ferrine Jan 22, 2017
a29709c
fix docstring
ferrine Jan 22, 2017
3363c98
add blank line in docs
ferrine Jan 23, 2017
81f1217
refactor ObjectiveFunction
ferrine Jan 23, 2017
d92afd2
add not working LS Op test
ferrine Jan 24, 2017
df54d77
experiments with not working LS Op
ferrine Jan 24, 2017
a0cd3df
change activations
ferrine Jan 25, 2017
2096601
refactor networks
ferrine Feb 10, 2017
adfb0e4
add step_function
ferrine Feb 10, 2017
8febbc1
remove Langevin Stein, done refactoring
ferrine Feb 13, 2017
1ffab9c
remove Langevin Stein, done refactoring
ferrine Feb 13, 2017
bff2d58
change optimizers
ferrine Feb 13, 2017
feab53c
refactor init params
ferrine Feb 13, 2017
76b1bf1
implement tests
ferrine Feb 13, 2017
80dee6d
implement Inference
ferrine Feb 13, 2017
1eda78b
code style
ferrine Feb 13, 2017
fcdeb1b
test fix
ferrine Feb 13, 2017
0e02929
add minibatch test (fails now)
ferrine Feb 15, 2017
dc2578d
add more tests for minibatch training
ferrine Feb 21, 2017
4638532
add logdet to FullRank approximation
ferrine Feb 21, 2017
4b55969
add conversion of arrays to floatX
ferrine Feb 21, 2017
4237f07
tiny changes
ferrine Feb 21, 2017
87b2cc5
change number of iterations
ferrine Feb 22, 2017
96aa930
fix test and pylint check
ferrine Feb 22, 2017
76c5fd7
memoize functions in Objective function
ferrine Feb 22, 2017
a59eebc
Optimize code a lot
ferrine Feb 25, 2017
9e8bf41
a bit more efficient pickling
ferrine Feb 26, 2017
afa8af2
add docs
ferrine Feb 26, 2017
4fcef3c
Add MeanField -> FullRank parameter transfer
ferrine Feb 26, 2017
6bfc243
refactor MeanField and FullRank a bit
ferrine Feb 26, 2017
3049d8f
fix FullRank bug with shapes in random
ferrine Feb 26, 2017
a7a31de
refactor Model.flatten (CC @taku-y)
ferrine Mar 2, 2017
32487c4
add `approximate` to inference
ferrine Mar 3, 2017
c415647
rename approximate->fit
ferrine Mar 6, 2017
3ccc7b8
change abbreviations
ferrine Mar 6, 2017
3c1b805
Fix bug with scaling input variable in aevb
ferrine Mar 8, 2017
6a7d84f
fix theano bottleneck in graph
ferrine Mar 10, 2017
c695990
more efficient scaling for local vars
ferrine Mar 10, 2017
92f0ca7
fix typo in local Q
ferrine Mar 13, 2017
b6ac76e
add aevb test
ferrine Mar 13, 2017
7a036f2
refactor memoize to work with my objects
ferrine Mar 13, 2017
57eb4c1
add tests for numpy view usage
ferrine Mar 13, 2017
1e8d2ec
pickle-hash fix
ferrine Mar 14, 2017
d90af39
pickle-hash fix again
ferrine Mar 14, 2017
2b7f84e
add node sampling + make up some code
ferrine Mar 14, 2017
c4da035
add notebook with example
ferrine Mar 14, 2017
6747dc3
sample_proba explained
ferrine Mar 14, 2017
865 changes: 865 additions & 0 deletions docs/source/notebooks/bayesian_neural_network_opvi-advi.ipynb

Large diffs are not rendered by default.

117 changes: 117 additions & 0 deletions pymc3/distributions/dist_math.py
@@ -9,6 +9,9 @@
 import theano.tensor as tt

 from .special import gammaln
+from ..math import logdet as _logdet
+
+c = - 0.5 * np.log(2 * np.pi)


def bound(logp, *conditions, **kwargs):
@@ -96,3 +99,117 @@ def i1(x):
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
+ 14175 / (98304 * x**4)))


def sd2rho(sd):
"""
`sd -> rho` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log(tt.exp(sd) - 1)


def rho2sd(rho):
"""
`rho -> sd` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log1p(tt.exp(rho))
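
As a quick sanity check (illustrative, not part of this diff; assumes a working Theano install), the two converters are mutual inverses:

    # rho2sd(sd2rho(sd)) = log1p(exp(log(exp(sd) - 1))) = sd
    import numpy as np
    import theano.tensor as tt
    from pymc3.distributions.dist_math import sd2rho, rho2sd

    sd = tt.as_tensor_variable(np.float64(2.5))
    np.testing.assert_allclose(rho2sd(sd2rho(sd)).eval(), 2.5)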


def log_normal(x, mean, **kwargs):
    """
    Calculate logarithm of normal distribution at point `x`
    with given `mean` and standard deviation
    Parameters
    ----------
    x : Tensor
        point of evaluation
    mean : Tensor
        mean of normal distribution
    kwargs : one of parameters `{sd, tau, w, rho}`, plus optional `eps`
    Notes
    -----
    There are four variants for density parametrization.
    They are:
    1) standard deviation - `sd`
    2) `w`, logarithm of `sd` :math:`w = log(sd)`
    3) `rho` that follows this equation :math:`rho = log(exp(sd) - 1)`
    4) `tau` that follows this equation :math:`tau = sd^{-1}`
    """
sd = kwargs.get('sd')
w = kwargs.get('w')
rho = kwargs.get('rho')
tau = kwargs.get('tau')
eps = kwargs.get('eps', 0.0)
check = sum(map(lambda a: a is not None, [sd, w, rho, tau]))
if check > 1:
raise ValueError('more than one required kwarg is passed')
if check == 0:
raise ValueError('none of required kwarg is passed')
if sd is not None:
std = sd
elif w is not None:
std = tt.exp(w)
elif rho is not None:
std = rho2sd(rho)
else:
std = tau**(-1)
std += eps
return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)
Review comment — @taku-y (Contributor), Feb 28, 2017:

    c is scaled with the number of elements in x: c -> c * x.ravel().shape[0] ?

Reply — @ferrine (Member, Author):

    There are elemwise operations, so that's ok here

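To make the four parametrizations concrete, a hypothetical equivalence check (not from the PR; assumes a working Theano install and the module path above):

    # All four parametrizations of log_normal agree for sd = 2.0.
    # Note that in this helper tau = sd**-1, not the usual tau = sd**-2.
    import numpy as np
    from pymc3.distributions.dist_math import log_normal

    x, mean, sd = 1.5, 0.0, 2.0
    ref = log_normal(x, mean, sd=sd).eval()
    for kwargs in ({'w': np.log(sd)},                # w = log(sd)
                   {'rho': np.log(np.exp(sd) - 1)},  # rho = log(exp(sd) - 1)
                   {'tau': 1.0 / sd}):               # tau = sd**-1
        np.testing.assert_allclose(log_normal(x, mean, **kwargs).eval(), ref)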

def log_normal_mv(x, mean, gpu_compat=False, **kwargs):
    """
    Calculate logarithm of multivariate normal distribution at point `x`
    with given `mean` and covariance parametrized by one of `kwargs`
    Parameters
    ----------
    x : Tensor
        point of evaluation
    mean : Tensor
        mean of normal distribution
    kwargs : one of parameters `{cov, tau, chol}`
    gpu_compat : bool, default False
        LogDet is not GPU compatible yet; if True, a GPU-compatible
        (but numerically unstable) log(det) is used instead.

    Notes
    -----
    There are three variants for density parametrization.
    They are:
    1) covariance matrix - `cov`
    2) precision matrix - `tau`
    3) cholesky decomposition matrix - `chol`
    """
if gpu_compat:
def logdet(m):
return tt.log(tt.abs_(tt.nlinalg.det(m)))
else:
logdet = _logdet

T = kwargs.get('tau')
S = kwargs.get('cov')
L = kwargs.get('chol')
check = sum(map(lambda a: a is not None, [T, S, L]))
if check > 1:
raise ValueError('more than one required kwarg is passed')
if check == 0:
raise ValueError('none of required kwarg is passed')
# avoid unnecessary computations
if L is not None:
S = L.dot(L.T)
T = tt.nlinalg.matrix_inverse(S)
log_det = -logdet(S)
elif T is not None:
log_det = logdet(T)
else:
T = tt.nlinalg.matrix_inverse(S)
log_det = -logdet(S)
delta = x - mean
k = S.shape[0]
result = k * tt.log(2 * np.pi) - log_det
result += delta.dot(T).dot(delta)
return -1 / 2. * result
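
A hedged sanity check for the three parametrizations (illustrative, not from the PR; assumes scipy and a working Theano install):

    # cov, tau and chol parametrizations should all match scipy's logpdf.
    import numpy as np
    import theano.tensor as tt
    import scipy.stats as st
    from pymc3.distributions.dist_math import log_normal_mv

    cov = np.array([[2.0, 0.3], [0.3, 1.0]])
    x, mean = np.array([0.5, -1.0]), np.zeros(2)
    ref = st.multivariate_normal(mean, cov).logpdf(x)
    xt, mt = tt.as_tensor_variable(x), tt.as_tensor_variable(mean)
    for kwargs in ({'cov': tt.as_tensor_variable(cov)},
                   {'tau': tt.as_tensor_variable(np.linalg.inv(cov))},
                   {'chol': tt.as_tensor_variable(np.linalg.cholesky(cov))}):
        np.testing.assert_allclose(log_normal_mv(xt, mt, **kwargs).eval(),
                                   ref, rtol=1e-6)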
6 changes: 5 additions & 1 deletion pymc3/math.py
@@ -37,13 +37,17 @@ def logit(p):
     return tt.log(p / (1 - p))


+def flatten_list(tensors):
+    return tt.concatenate([var.ravel() for var in tensors])
+
+
 class LogDet(Op):
     """Computes the logarithm of absolute determinant of a square
     matrix M, log(abs(det(M))), on CPU. Avoids det(M) overflow/
     underflow.

     Note: Once PR #3959 (https://github.com/Theano/Theano/pull/3959/) by harpone is merged,
     this must be removed.
     """
     def make_node(self, x):
         x = theano.tensor.as_tensor_variable(x)
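Why the Op computes the log directly, sketched (illustrative, not from the PR; uses the CPU-only `logdet` exported above):

    # det() of a large matrix can underflow to 0.0 even when the matrix is
    # perfectly well-behaved; log|det| stays finite.
    import numpy as np
    import theano.tensor as tt
    from pymc3.math import logdet

    m = 0.01 * np.eye(500)                          # det(m) = 0.01 ** 500
    print(np.linalg.det(m))                         # 0.0 -- underflows
    print(logdet(tt.as_tensor_variable(m)).eval())  # ~ 500 * log(0.01) ~= -2302.6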
17 changes: 13 additions & 4 deletions pymc3/memoize.py
@@ -1,4 +1,5 @@
 import functools
+import pickle


 def memoize(obj):
@@ -23,8 +24,16 @@ def hashable(a):
     Turn some unhashable objects into hashable ones.
     """
     if isinstance(a, dict):
-        return hashable(a.items())
+        return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
     try:
-        return tuple(map(hashable, a))
-    except:
-        return a
+        return hash(a)
+    except TypeError:
+        pass
+    # Not hashable >>>
+    try:
+        return hash(pickle.dumps(a))
+    except Exception:
+        if hasattr(a, '__dict__'):
+            return hashable(a.__dict__)
+        else:
+            return id(a)
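
The effect of the pickle fallback, illustrated (hypothetical usage, not from the PR):

    # hash() rejects dicts and numpy arrays; hashable() now derives a stable
    # in-process key from their pickled bytes, so equal contents give equal keys.
    import numpy as np
    from pymc3.memoize import hashable

    key1 = hashable({'mu': 0.0, 'sd': np.ones(3)})
    key2 = hashable({'mu': 0.0, 'sd': np.ones(3)})
    assert key1 == key2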
40 changes: 36 additions & 4 deletions pymc3/model.py
@@ -1,14 +1,15 @@
+import collections
 import threading
 import six

 import numpy as np
 import scipy.sparse as sps
-import theano
-import theano.tensor as tt
 import theano.sparse as sparse
+from theano import theano, tensor as tt
 from theano.tensor.var import TensorVariable

 import pymc3 as pm
+from pymc3.math import flatten_list
 from .memoize import memoize
 from .theanof import gradient, hessian, inputvars, generator
 from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
@@ -19,6 +20,8 @@
     'Point', 'Deterministic', 'Potential'
 ]

+FlatView = collections.namedtuple('FlatView', 'input, replacements, view')
+

 class InstanceMethod(object):
     """Class for hiding references to instance methods so they can be pickled.
@@ -172,8 +175,10 @@ def fastd2logp(self, vars=None):
     @property
     def logpt(self):
         """Theano scalar of log-probability of the model"""
-
-        return tt.sum(self.logp_elemwiset) * self.scaling
+        if getattr(self, 'total_size', None) is not None:
+            return tt.sum(self.logp_elemwiset) * self.scaling
+        else:
+            return tt.sum(self.logp_elemwiset)

     @property
     def scaling(self):
@@ -659,6 +664,33 @@ def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):

return f.profile

def flatten(self, vars=None):
"""Flattens model's input and returns:
FlatView with
* input vector variable
* replacements `input_var -> vars`
* view {variable: VarMap}

Parameters
----------
vars : list of variables or None
if None, then all model.free_RVs are used for flattening input

Returns
-------
flat_view
"""
if vars is None:
vars = self.free_RVs
order = ArrayOrdering(vars)
inputvar = tt.vector('flat_view', dtype=theano.config.floatX)
inputvar.tag.test_value = flatten_list(vars).tag.test_value
replacements = {self.named_vars[name]: inputvar[slc].reshape(shape).astype(dtype)
for name, slc, shape, dtype in order.vmap}
view = {vm.var: vm for vm in order.vmap}
flat_view = FlatView(inputvar, replacements, view)
return flat_view


def fn(outs, mode=None, model=None, *args, **kwargs):
"""Compiles a Theano function which returns the values of `outs` and
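A minimal usage sketch for `Model.flatten` (hypothetical two-variable model, not from the PR): clone the model's logp onto the flat input vector and evaluate it at a point:

    # Evaluate model.logpt as a function of a single flat parameter vector.
    import numpy as np
    import theano
    import pymc3 as pm

    with pm.Model() as model:
        pm.Normal('a', mu=0, sd=1, shape=2)
        pm.Normal('b', mu=0, sd=1)

    flat = model.flatten()   # FlatView(input, replacements, view)
    logp = theano.clone(model.logpt, flat.replacements)
    f = theano.function([flat.input], logp)
    print(f(np.zeros(3, dtype=theano.config.floatX)))  # logp at a=[0, 0], b=0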
9 changes: 8 additions & 1 deletion pymc3/tests/helpers.py
@@ -1,6 +1,8 @@
 import unittest
-import numpy.random as nr
 from logging.handlers import BufferingHandler
+import numpy.random as nr
+from theano.sandbox.rng_mrg import MRG_RandomStreams
+from ..theanof import set_tt_rng, tt_rng


 class SeededTest(unittest.TestCase):
@@ -12,6 +14,11 @@ def setUpClass(cls):

     def setUp(self):
         nr.seed(self.random_seed)
+        self.old_tt_rng = tt_rng()
+        set_tt_rng(MRG_RandomStreams(self.random_seed))
+
+    def tearDown(self):
+        set_tt_rng(self.old_tt_rng)

 class TestHandler(BufferingHandler):
     def __init__(self, matcher):
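The seeding contract the helper now enforces, sketched (illustrative, not from the PR):

    # Reseeding pymc3's global Theano RNG reproduces the same draws,
    # which is what setUp/tearDown above guarantee per test.
    import numpy as np
    from theano.sandbox.rng_mrg import MRG_RandomStreams
    from pymc3.theanof import set_tt_rng, tt_rng

    set_tt_rng(MRG_RandomStreams(42))
    a = tt_rng().normal(size=(3,)).eval()
    set_tt_rng(MRG_RandomStreams(42))
    b = tt_rng().normal(size=(3,)).eval()
    np.testing.assert_array_equal(a, b)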
1 change: 1 addition & 0 deletions pymc3/tests/test_examples.py
@@ -234,6 +234,7 @@ class TestLatentOccupancy(SeededTest):
     Copyright (c) 2008 University of Otago. All rights reserved.
     """
     def setUp(self):
+        super(TestLatentOccupancy, self).setUp()
         # Sample size
         n = 100
         # True mean count, given occupancy
4 changes: 3 additions & 1 deletion pymc3/tests/test_math.py
@@ -4,13 +4,15 @@
 from pymc3.math import LogDet, logdet, probit, invprobit
 from .helpers import SeededTest

+
 def test_probit():
     p = np.array([0.01, 0.25, 0.5, 0.75, 0.99])
     np.testing.assert_allclose(invprobit(probit(p)).eval(), p, atol=1e-5)

-class TestLogDet(SeededTest):
+
+class TestLogDet(SeededTest):
     def setUp(self):
+        super(TestLogDet, self).setUp()
         utt.seed_rng()
         self.op_class = LogDet
         self.op = logdet