Skip to content

Vi summary #2230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_fit(method, kwargs, error):
'ord',
[1, 2, np.inf]
)
def test_callbacks(diff, ord):
def test_callbacks_convergence(diff, ord):
cb = pm.variational.callbacks.CheckParametersConvergence(every=1, diff=diff, ord=ord)

class _approx:
Expand All @@ -406,3 +406,27 @@ class _approx:
with pytest.raises(StopIteration):
cb(approx, None, 1)
cb(approx, None, 10)


def test_tracker_callback():
import time
tracker = pm.callbacks.Tracker(
ints=lambda *t: t[-1],
ints2=lambda ap, h, j: j,
time=time.time,
)
for i in range(10):
tracker(None, None, i)
assert 'time' in tracker.hist
assert 'ints' in tracker.hist
assert 'ints2' in tracker.hist
assert (len(tracker['ints'])
== len(tracker['ints2'])
== len(tracker['time'])
== 10)
assert tracker['ints'] == tracker['ints2'] == list(range(10))
tracker = pm.callbacks.Tracker(
bad=lambda t: t # bad signature
)
with pytest.raises(TypeError):
tracker(None, None, 1)
67 changes: 66 additions & 1 deletion pymc3/variational/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import collections

import numpy as np

__all__ = [
'Callback',
'CheckParametersConvergence'
'CheckParametersConvergence',
'Tracker'
]


Expand Down Expand Up @@ -76,3 +79,65 @@ def __call__(self, approx, _, i):
@staticmethod
def flatten_shared(shared_list):
return np.concatenate([sh.get_value().flatten() for sh in shared_list])


class Tracker(Callback):
"""
Helper class to record arbitrary stats during VI

It is possible to pass a function that takes no arguments
If call fails then (approx, hist, i) are passed


Parameters
----------
kwargs : key word arguments
keys mapping statname to callable that records the stat

Examples
--------
Consider we want time on each iteration
>>> import time
>>> tracker = Tracker(time=time.time)
>>> with model:
... approx = pm.fit(callbacks=[tracker])

Time can be accessed via :code:`tracker['time']` now
For more complex summary one can use callable that takes
(approx, hist, i) as arguments
>>> with model:
... my_callable = lambda ap, h, i: h[-1]
... tracker = Tracker(some_stat=my_callable)
... approx = pm.fit(callbacks=[tracker])

Multiple stats are valid too
>>> with model:
... tracker = Tracker(some_stat=my_callable, time=time.time)
... approx = pm.fit(callbacks=[tracker])
"""
def __init__(self, **kwargs):
self.whatchdict = kwargs
self.hist = collections.defaultdict(list)

def record(self, approx, hist, i):
for key, fn in self.whatchdict.items():
try:
res = fn()
# if `*t` argument is used
# fail will be somehow detected.
# We want both calls to be tried.
# Upper one has more priority as
# arbitrary functions can have some
# defaults in positionals. Bad idea
# to try fn(approx, hist, i) first
except Exception:
res = fn(approx, hist, i)
self.hist[key].append(res)

def clear(self):
self.hist = collections.defaultdict(list)

def __getitem__(self, item):
return self.hist[item]

__call__ = record
5 changes: 2 additions & 3 deletions pymc3/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import logging
import warnings
import tqdm

import numpy as np
import tqdm

import pymc3 as pm
from pymc3.variational import test_functions
from pymc3.variational.approximations import MeanField, FullRank, Empirical
from pymc3.variational.operators import KL, KSD, AKSD
from pymc3.variational.opvi import Approximation
from pymc3.variational import test_functions


logger = logging.getLogger(__name__)

Expand Down