Skip to content

Added live_traceplot function #1934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 56 commits into from
Mar 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
9d5b20a
Added live_traceplot function
davidbrochart Mar 22, 2017
616a876
Cosmetic change
davidbrochart Mar 22, 2017
6f54f81
Merge branch 'master' into live_sample_plots
davidbrochart Mar 23, 2017
de127b8
Changed the API to pm.sample(..., live_plot=True)
davidbrochart Mar 23, 2017
bdd25af
Add tutorial to detect sampling problems (#1866)
Mar 7, 2017
8af47d8
DOC Change heading names.
twiecki Mar 7, 2017
07b3d15
Make install scripts idempotent (#1879)
ColCarroll Mar 7, 2017
bc62f3e
Add examples of censored data models (#1870)
Mar 7, 2017
6261937
Don't include `-np.inf` in calculating average ELBO (#1880)
pstjohn Mar 7, 2017
5417f0c
Raise TypeError on non-data values of observed (#1872)
fonnesbeck Mar 8, 2017
e6317ac
Make exponential mode have the correct shape
AustinRochford Mar 8, 2017
04d3fba
Added tutorial notebook on updating priors
davidbrochart Mar 7, 2017
ebb48c9
Made small changes and executed the notebook
davidbrochart Mar 8, 2017
a42e25c
Fixed y-axis bug in forestplot; added transform argument to summary
fonnesbeck Mar 7, 2017
e125a71
Style cleanup
fonnesbeck Mar 8, 2017
28022c3
Added probit and invprobit functions
fonnesbeck Mar 8, 2017
c5fb96e
Added carriage return to end of file
fonnesbeck Mar 8, 2017
695b49b
Fixed indentation
fonnesbeck Mar 8, 2017
caa5318
Changed probit test to use assert_allclose
fonnesbeck Mar 8, 2017
c122885
Fix support of LKJCorr
aseyboldt Mar 5, 2017
e3cf77b
Fix tests for LKJCorr
aseyboldt Mar 9, 2017
a3825d7
Added warning for ignoring init arguments in sample
fonnesbeck Mar 10, 2017
6fcee15
Kill stray tab
fonnesbeck Mar 10, 2017
030205e
Improve performance of transformations
aseyboldt Mar 11, 2017
de6c5a0
DOC Add new features
twiecki Mar 13, 2017
a47f27c
Bump version.
twiecki Mar 13, 2017
416e6f2
WIP: Implement opvi (#1694)
ferrine Mar 15, 2017
966b1ed
delete unnecessary text and add some benchmarks (#1901)
ferrine Mar 16, 2017
8e0c489
doc(DiagInferDiv): formatting fix in blog post quote. Closes #1895. (…
alexandercbooth Mar 16, 2017
c05282b
Revert "small fix for multivariate mixture models"
AustinRochford Mar 15, 2017
6c7c600
Added message about init only working with auto-assigned step methods
fonnesbeck Mar 15, 2017
de77fca
Added docs and scripts to MANIFEST
fonnesbeck Mar 14, 2017
ce3a5e6
Added newline to MANIFEST
fonnesbeck Mar 14, 2017
9025d7f
Replaced package list with find_packages in setup.py; removed example…
fonnesbeck Mar 14, 2017
75da69c
Updated version to rc2
fonnesbeck Mar 16, 2017
d6bd886
Fixed stray version string
fonnesbeck Mar 16, 2017
47f6e1b
refactor variational module, add histogram approximation (#1904)
ferrine Mar 17, 2017
0dbd346
Fix indexing traces with steps greater one
aseyboldt Mar 16, 2017
3a0d654
SVGD problems (#1916)
ferrine Mar 18, 2017
31d514d
Histogram docs (#1914)
ferrine Mar 18, 2017
4240ea5
improve aesthetics
aloctavodia Mar 18, 2017
f6190fc
Bump theano to 0.9.0rc4 (#1921)
ColCarroll Mar 20, 2017
96ca0ac
Histogram: use only free RVs from trace (#1926)
ferrine Mar 21, 2017
33406e8
small fix to prevent a TypeError with the ufunc true_divide
aloctavodia Mar 21, 2017
6f984b3
Bump theano to be at least 0.9.0
ColCarroll Mar 21, 2017
4d531be
Add LKJCholeskyCov
aseyboldt Mar 14, 2017
6a718fa
Fix log jacobian in LKJCholeskyCov
aseyboldt Mar 15, 2017
34b716f
Add documentation for LKJCholeskyCov
aseyboldt Mar 17, 2017
d8566f9
Add expand_packed_triangular
aseyboldt Mar 18, 2017
62835ae
Add tests for LKJCholeskyCov
aseyboldt Mar 20, 2017
69aab74
Fix tests for py2
aseyboldt Mar 22, 2017
8326741
Add floatX wrappers in test_advi
kyleabeauchamp Mar 23, 2017
45d0887
Changed the API to pm.sample(..., live_plot=True)
davidbrochart Mar 23, 2017
44b12e6
Merge branch 'live_sample_plots' of https://github.com/davidbrochart/…
davidbrochart Mar 27, 2017
0f782cf
Better formatting
davidbrochart Mar 27, 2017
9083ac8
Merged with lastest master
davidbrochart Mar 27, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions docs/source/notebooks/live_sample_plots.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Live sample plots"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"This notebook illustrates how we can have live sample plots when calling the `sample` function with `live_plot=True`. It is based on the \"Coal mining disasters\" case study in the [Getting started notebook](https://github.com/pymc-devs/pymc3/blob/master/docs/source/notebooks/getting_started.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from pymc3 import Model, Exponential, DiscreteUniform, Poisson, sample\n",
"from pymc3.math import switch\n",
"\n",
"%matplotlib notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"disaster_data = np.ma.masked_values([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,\n",
" 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,\n",
" 2, 2, 3, 4, 2, 1, 3, -999, 2, 1, 1, 1, 1, 3, 0, 0,\n",
" 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,\n",
" 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,\n",
" 3, 3, 1, -999, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,\n",
" 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], value=-999)\n",
"year = np.arange(1851, 1962)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"with Model() as disaster_model:\n",
"\n",
" switchpoint = DiscreteUniform('switchpoint', lower=year.min(), upper=year.max(), testval=1900)\n",
"\n",
" # Priors for pre- and post-switch rates number of disasters\n",
" early_rate = Exponential('early_rate', 1)\n",
" late_rate = Exponential('late_rate', 1)\n",
"\n",
" # Allocate appropriate Poisson rates to years before and after current\n",
" rate = switch(switchpoint >= year, early_rate, late_rate)\n",
"\n",
" disasters = Poisson('disasters', rate, observed=disaster_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": false
},
"outputs": [],
"source": [
"with disaster_model:\n",
" trace = sample(10000, live_plot=True, skip_first=100, refresh_every=300, roll_over=1000)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
38 changes: 36 additions & 2 deletions pymc3/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, lines=None,
combined=False, plot_transformed=False, grid=False, alpha=0.35, priors=None,
prior_alpha=1, prior_style='--', ax=None):
prior_alpha=1, prior_style='--', ax=None, live_plot=False,
skip_first=0, refresh_every=100, roll_over=1000):
"""Plot samples histograms and values.

Parameters
Expand Down Expand Up @@ -45,6 +46,16 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
Line style for prior plot. Defaults to '--' (dashed line).
ax : axes
Matplotlib axes. Accepts an array of axes, e.g.:
live_plot: bool
Flag for updating the current figure while sampling
skip_first : int
Number of first samples not shown in plots (burn-in). This affects
frequency and stream plots.
refresh_every : int
Period of plot updates (in sample number)
roll_over : int
Width of the sliding window for the sample stream plots: last roll_over
samples are shown (no effect on frequency plots).

>>> fig, axs = plt.subplots(3, 2) # 3 RVs
>>> pymc3.traceplot(trace, ax=axs)
Expand All @@ -57,6 +68,8 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
ax : matplotlib axes

"""
trace = trace[skip_first:]

if varnames is None:
varnames = get_default_varnames(trace, plot_transformed)

Expand All @@ -70,9 +83,23 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
prior = priors[i]
else:
prior = None
first_time = True
for d in trace.get_values(v, combine=combined, squeeze=False):
d = np.squeeze(transform(d))
d = make_2d(d)
d_stream = d
x0 = 0
if live_plot:
x0 = skip_first
if first_time:
ax[i, 0].cla()
ax[i, 1].cla()
first_time = False
if roll_over is not None:
if len(d) >= roll_over:
x0 = len(d) - roll_over + skip_first
d_stream = d[-roll_over:]
width = len(d_stream)
if d.dtype.kind == 'i':
hist_objs = histplot_op(ax[i, 0], d, alpha=alpha)
colors = [h[-1][0].get_facecolor() for h in hist_objs]
Expand All @@ -82,7 +109,7 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
ax[i, 0].set_title(str(v))
ax[i, 0].grid(grid)
ax[i, 1].set_title(str(v))
ax[i, 1].plot(d, alpha=alpha)
ax[i, 1].plot(range(x0, x0 + width), d_stream, alpha=alpha)

ax[i, 0].set_ylabel("Frequency")
ax[i, 1].set_ylabel("Sample value")
Expand All @@ -103,6 +130,13 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None,
lw=1.5, alpha=alpha)
except KeyError:
pass
if live_plot:
for j in [0, 1]:
ax[i, j].relim()
ax[i, j].autoscale_view(True, True, True)
ax[i, 1].set_xlim(x0, x0 + width)
ax[i, 0].set_ylim(ymin=0)
if live_plot:
ax[0, 0].figure.canvas.draw()
plt.tight_layout()
return ax
28 changes: 23 additions & 5 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
Slice, CompoundStep)
from .plots.utils import identity_transform
from .plots.traceplot import traceplot
from tqdm import tqdm

import warnings
Expand Down Expand Up @@ -85,7 +87,7 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol

def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
trace=None, chain=0, njobs=1, tune=None, progressbar=True,
model=None, random_seed=-1):
model=None, random_seed=-1, live_plot=False, **kwargs):
"""Draw samples from the posterior using the given step methods.

Multiple step methods are supported via compound step methods.
Expand Down Expand Up @@ -141,6 +143,8 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
model : Model (optional if in `with` context)
random_seed : int or list of ints
A list is accepted if more if `njobs` is greater than one.
live_plot: bool
Flag for live plotting the trace while sampling

Returns
-------
Expand Down Expand Up @@ -175,7 +179,9 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,
'tune': tune,
'progressbar': progressbar,
'model': model,
'random_seed': random_seed}
'random_seed': random_seed,
'live_plot': live_plot,
**kwargs}

if njobs > 1:
sample_func = _mp_sample
Expand All @@ -187,15 +193,27 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None,


def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
progressbar=True, model=None, random_seed=-1):
progressbar=True, model=None, random_seed=-1, live_plot=False,
**kwargs):
live_plot_args = {'skip_first': 0, 'refresh_every': 100}
live_plot_args = {arg: kwargs[arg] if arg in kwargs else live_plot_args[arg] for arg in live_plot_args}
skip_first = live_plot_args['skip_first']
refresh_every = live_plot_args['refresh_every']

sampling = _iter_sample(draws, step, start, trace, chain,
tune, model, random_seed)
if progressbar:
sampling = tqdm(sampling, total=draws)
try:
strace = None
for strace in sampling:
pass
for it, strace in enumerate(sampling):
if live_plot:
if it >= skip_first:
trace = MultiTrace([strace])
if it == skip_first:
ax = traceplot(trace, live_plot=False, **kwargs)
elif (it - skip_first) % refresh_every == 0 or it == draws - 1:
traceplot(trace, ax=ax, live_plot=True, **kwargs)
except KeyboardInterrupt:
pass
finally:
Expand Down