diff --git a/docs/source/notebooks/live_sample_plots.ipynb b/docs/source/notebooks/live_sample_plots.ipynb new file mode 100644 index 0000000000..e851acea72 --- /dev/null +++ b/docs/source/notebooks/live_sample_plots.ipynb @@ -0,0 +1,122 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "# Live sample plots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "deletable": true, + "editable": true + }, + "source": [ + "This notebook illustrates how we can have live sample plots when calling the `sample` function with `live_plot=True`. It is based on the \"Coal mining disasters\" case study in the [Getting started notebook](https://github.com/pymc-devs/pymc3/blob/master/docs/source/notebooks/getting_started.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from pymc3 import Model, Exponential, DiscreteUniform, Poisson, sample\n", + "from pymc3.math import switch\n", + "\n", + "%matplotlib notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "disaster_data = np.ma.masked_values([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,\n", + " 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,\n", + " 2, 2, 3, 4, 2, 1, 3, -999, 2, 1, 1, 1, 1, 3, 0, 0,\n", + " 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,\n", + " 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,\n", + " 3, 3, 1, -999, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,\n", + " 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], value=-999)\n", + "year = np.arange(1851, 1962)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "with Model() as disaster_model:\n", + "\n", + " switchpoint = DiscreteUniform('switchpoint', lower=year.min(), upper=year.max(), testval=1900)\n", + "\n", + " # Priors for pre- and post-switch rates number of disasters\n", + " early_rate = Exponential('early_rate', 1)\n", + " late_rate = Exponential('late_rate', 1)\n", + "\n", + " # Allocate appropriate Poisson rates to years before and after current\n", + " rate = switch(switchpoint >= year, early_rate, late_rate)\n", + "\n", + " disasters = Poisson('disasters', rate, observed=disaster_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true, + "scrolled": false + }, + "outputs": [], + "source": [ + "with disaster_model:\n", + " trace = sample(10000, live_plot=True, skip_first=100, refresh_every=300, roll_over=1000)" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [default]", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pymc3/plots/traceplot.py b/pymc3/plots/traceplot.py index 8ccbca0900..6197b61af8 100644 --- a/pymc3/plots/traceplot.py +++ b/pymc3/plots/traceplot.py @@ -7,7 +7,8 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, lines=None, combined=False, plot_transformed=False, grid=False, alpha=0.35, priors=None, - prior_alpha=1, prior_style='--', ax=None): + prior_alpha=1, prior_style='--', ax=None, live_plot=False, + skip_first=0, refresh_every=100, roll_over=1000): """Plot samples histograms and values. Parameters @@ -45,6 +46,16 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, Line style for prior plot. Defaults to '--' (dashed line). ax : axes Matplotlib axes. Accepts an array of axes, e.g.: + live_plot: bool + Flag for updating the current figure while sampling + skip_first : int + Number of first samples not shown in plots (burn-in). This affects + frequency and stream plots. + refresh_every : int + Period of plot updates (in sample number) + roll_over : int + Width of the sliding window for the sample stream plots: last roll_over + samples are shown (no effect on frequency plots). >>> fig, axs = plt.subplots(3, 2) # 3 RVs >>> pymc3.traceplot(trace, ax=axs) @@ -57,6 +68,8 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, ax : matplotlib axes """ + trace = trace[skip_first:] + if varnames is None: varnames = get_default_varnames(trace, plot_transformed) @@ -70,9 +83,23 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, prior = priors[i] else: prior = None + first_time = True for d in trace.get_values(v, combine=combined, squeeze=False): d = np.squeeze(transform(d)) d = make_2d(d) + d_stream = d + x0 = 0 + if live_plot: + x0 = skip_first + if first_time: + ax[i, 0].cla() + ax[i, 1].cla() + first_time = False + if roll_over is not None: + if len(d) >= roll_over: + x0 = len(d) - roll_over + skip_first + d_stream = d[-roll_over:] + width = len(d_stream) if d.dtype.kind == 'i': hist_objs = histplot_op(ax[i, 0], d, alpha=alpha) colors = [h[-1][0].get_facecolor() for h in hist_objs] @@ -82,7 +109,7 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, ax[i, 0].set_title(str(v)) ax[i, 0].grid(grid) ax[i, 1].set_title(str(v)) - ax[i, 1].plot(d, alpha=alpha) + ax[i, 1].plot(range(x0, x0 + width), d_stream, alpha=alpha) ax[i, 0].set_ylabel("Frequency") ax[i, 1].set_ylabel("Sample value") @@ -103,6 +130,13 @@ def traceplot(trace, varnames=None, transform=identity_transform, figsize=None, lw=1.5, alpha=alpha) except KeyError: pass + if live_plot: + for j in [0, 1]: + ax[i, j].relim() + ax[i, j].autoscale_view(True, True, True) + ax[i, 1].set_xlim(x0, x0 + width) ax[i, 0].set_ylim(ymin=0) + if live_plot: + ax[0, 0].figure.canvas.draw() plt.tight_layout() return ax diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 32f022c489..dc4614bf3b 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -11,6 +11,8 @@ from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) +from .plots.utils import identity_transform +from .plots.traceplot import traceplot from tqdm import tqdm import warnings @@ -85,7 +87,7 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol def sample(draws, step=None, init='ADVI', n_init=200000, start=None, trace=None, chain=0, njobs=1, tune=None, progressbar=True, - model=None, random_seed=-1): + model=None, random_seed=-1, live_plot=False, **kwargs): """Draw samples from the posterior using the given step methods. Multiple step methods are supported via compound step methods. @@ -141,6 +143,8 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None, model : Model (optional if in `with` context) random_seed : int or list of ints A list is accepted if more if `njobs` is greater than one. + live_plot: bool + Flag for live plotting the trace while sampling Returns ------- @@ -175,7 +179,9 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None, 'tune': tune, 'progressbar': progressbar, 'model': model, - 'random_seed': random_seed} + 'random_seed': random_seed, + 'live_plot': live_plot, + **kwargs} if njobs > 1: sample_func = _mp_sample @@ -187,15 +193,27 @@ def sample(draws, step=None, init='ADVI', n_init=200000, start=None, def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None, - progressbar=True, model=None, random_seed=-1): + progressbar=True, model=None, random_seed=-1, live_plot=False, + **kwargs): + live_plot_args = {'skip_first': 0, 'refresh_every': 100} + live_plot_args = {arg: kwargs[arg] if arg in kwargs else live_plot_args[arg] for arg in live_plot_args} + skip_first = live_plot_args['skip_first'] + refresh_every = live_plot_args['refresh_every'] + sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed) if progressbar: sampling = tqdm(sampling, total=draws) try: strace = None - for strace in sampling: - pass + for it, strace in enumerate(sampling): + if live_plot: + if it >= skip_first: + trace = MultiTrace([strace]) + if it == skip_first: + ax = traceplot(trace, live_plot=False, **kwargs) + elif (it - skip_first) % refresh_every == 0 or it == draws - 1: + traceplot(trace, ax=ax, live_plot=True, **kwargs) except KeyboardInterrupt: pass finally: