Submit BEST #1517
@@ -0,0 +1,168 @@
"""
Written by: Eric J. Ma
Date: 11 November 2016
Inspiration taken from many places, including the PyMC3 documentation.

A note on the API design, for future contributors.

There were some design choices made here for "default models" that may be
modified in the future. I list the choices below:

- A "model" is an object, as in scikit-learn.
- A model is instantiated with a DataFrame that houses the data.
- Models accept other parameters as necessary.
- Every model has a `.fit()` method that performs model fitting, as in
  scikit-learn.
- Every model has a `.plot_posterior()` method that returns a figure showing
  the posterior distribution. Inspired by the PyMC3 GLM module.
- BEST uses ADVI, but this can (and should) be made an option; MCMC is also a
  good tool to use.
"""

from ..distributions import StudentT, Exponential, Uniform, HalfCauchy
from .. import Model
from ..variational import advi, sample_vp
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
class BEST(object):
    """BEST model, based on Kruschke (2013).

    Parameters
    ----------
    data : pandas DataFrame
        A pandas DataFrame with the following properties:
        - Each row is one replicate measurement.
        - There is a column that records the treatment name.
        - There is a column that records the measured value for that
          replicate.

    sample_col : str
        The name of the column containing sample names.

    output_col : str
        The name of the column containing the values to estimate.

    baseline_name : str
        The name of the "control" or "baseline" sample.

    Output
    ------
    model : PyMC3 model
        Returns the BEST model containing the fitted posterior.
    """
    def __init__(self, data, sample_col, output_col, baseline_name):
        super(BEST, self).__init__()
        self.data = data
        self.sample_col = sample_col
        self.output_col = output_col
        self.baseline_name = baseline_name
        self.trace = None

        self._convert_to_indices()
    def _convert_to_indices(self):
        """
        Adds the "indices" column to self.data (DataFrame). This is necessary
        for the simplified model specification in the "fit" function below.
        """
        sample_names = dict()
        for i, name in enumerate(
                list(np.unique(self.data[self.sample_col].values))):
            print('Sample name {0} has the index {1}'.format(name, i))
            sample_names[name] = i
        self.data['indices'] = self.data[self.sample_col].apply(
            lambda x: sample_names[x])

Review thread on `_convert_to_indices`:
- I think it is better to write a static function without side effects. By the way, it seems to modify the original DataFrame, which is bad for the user.
- Agreed, the side effects here concern me.
- (On the `print` call) I suggest using the
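The reviewers' side-effect concern could be met with a standalone mapping function. A minimal sketch (the function name is hypothetical; plain Python, mirroring the sorted order that `np.unique` produces) that builds the name-to-index mapping without touching the caller's DataFrame:

```python
def sample_name_indices(names):
    """Map each unique sample name to an integer index (sorted order),
    returning new objects instead of mutating caller-owned data."""
    mapping = {name: i for i, name in enumerate(sorted(set(names)))}
    indices = [mapping[n] for n in names]
    return indices, mapping


indices, mapping = sample_name_indices(['ctrl', 'drugA', 'ctrl', 'drugB'])
print(mapping)   # {'ctrl': 0, 'drugA': 1, 'drugB': 2}
print(indices)   # [0, 1, 0, 2]
```

The caller (e.g. `fit`) could then assign the result to a copy of the DataFrame, leaving the user's original data unchanged.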
Review thread on `fit`:
- It is not possible to customize the ADVI fit :(
- I plan to change this to use @twiecki's
- I think that's a good idea; it is after all what @twiecki's

    def fit(self, n_steps=500000):
        """
        Creates a Bayesian Estimation model for replicate measurements of
        treatment(s) vs. control.

        Parameters
        ----------
        n_steps : int
            The number of steps for which to run ADVI.
        """
        sample_names = set(self.data[self.sample_col].values)

        mean_test = self.data.groupby('indices').mean()[self.output_col].values
        sd_test = self.data.groupby('indices').std()[self.output_col].values
Review comment on the `with Model()` context: This adds limitations to usage; it will not be possible to use a model within a model. Am I right, @twiecki?

        with Model() as model:
            # Hyperpriors
            upper = Exponential('upper', lam=0.05)
            nu = Exponential('nu_minus_one', 1/29.) + 1

            # "fold", which is the estimated fold change.
            fold = Uniform('fold', lower=1E-10, upper=upper,
                           shape=len(sample_names), testval=mean_test)

            # Assume that data have heteroskedastic (i.e. variable) error but
            # are drawn from the same HalfCauchy distribution.
            sigma = HalfCauchy('sigma', beta=1, shape=len(sample_names),
                               testval=sd_test)

            # Model prediction
            mu = fold[self.data['indices']]
            sig = sigma[self.data['indices']]

            # Data likelihood
            like = StudentT('like', nu=nu, mu=mu, sd=sig**-2,
                            observed=self.data[self.output_col])

            params = advi(n=n_steps)
            trace = sample_vp(params, draws=2000)

        self.trace = trace
        self.params = params
        self.model = model
Review thread on `plot_posterior`:
- Should also return the axes for further user modifications.
- Great point, I'll update the PR with a fix for this.

    def plot_posterior(self, rotate_xticks=False):
        """
        Plots a swarm plot of the data overlaid on top of the 95% HPD and IQR
        of the posterior distribution.
        """
        # Make summary plot #
        fig = plt.figure()
        ax = fig.add_subplot(111)

        # 1. Get the lower and upper errorbars for the 95% HPD and IQR.
        lower, lower_q, upper_q, upper = np.percentile(self.trace['fold'],
                                                       [2.5, 25, 75, 97.5],
                                                       axis=0)
        summary_stats = pd.DataFrame()
        summary_stats['mean'] = self.trace['fold'].mean(axis=0)
        err_low = summary_stats['mean'] - lower
        err_high = upper - summary_stats['mean']
        iqr_low = summary_stats['mean'] - lower_q
        iqr_high = upper_q - summary_stats['mean']

        # 2. Plot the swarmplot and errorbars.
        summary_stats['mean'].plot(ls='', ax=ax,
                                   yerr=[err_low, err_high])
        summary_stats['mean'].plot(ls='', ax=ax,
                                   yerr=[iqr_low, iqr_high],
                                   elinewidth=4, color='red')
        sns.swarmplot(data=self.data, x=self.sample_col, y=self.output_col,
                      ax=ax, alpha=0.5)

        if rotate_xticks:
            print('rotating xticks')
            plt.xticks(rotation='vertical')
        plt.ylabel(self.output_col)

        return fig, ax
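The error-bar arithmetic in `plot_posterior` converts absolute percentile bounds into distances from the mean, which is the form matplotlib's `yerr` expects. A tiny worked sketch with made-up numbers for a single variable:

```python
# Hypothetical posterior summary statistics for one variable.
mean = 2.0
# 2.5th, 25th, 75th, 97.5th percentiles of the posterior samples:
lower, lower_q, upper_q, upper = 1.0, 1.5, 2.5, 3.5

# yerr wants offsets from the plotted point, not absolute bounds.
err_low = mean - lower      # 1.0  (distance down to the 2.5th percentile)
err_high = upper - mean     # 1.5  (distance up to the 97.5th percentile)
iqr_low = mean - lower_q    # 0.5  (distance down to the 25th percentile)
iqr_high = upper_q - mean   # 0.5  (distance up to the 75th percentile)
print(err_low, err_high, iqr_low, iqr_high)   # 1.0 1.5 0.5 0.5
```

Getting these as offsets rather than bounds is easy to get wrong; passing the raw percentiles to `yerr` would draw wildly oversized bars.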
    def plot_elbo(self):
        """
        Plots the ELBO values to help check for convergence.
        """
        fig = plt.figure()
        plt.plot(-np.log10(-self.params.elbo_vals))

        return fig

Review comment on the file: As for the base class, we should consider using Lasagne's base-class ideas. I'll do a sketch for that soon.