-
-
Notifications
You must be signed in to change notification settings - Fork 1
Add Forest dashboard #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4ff9eb0
class integration issue
yilinxia f94e580
fix multi dropdown
yilinxia 8349984
formalize the function
yilinxia c3bb9e0
first draft dashboard
yilinxia a042a52
comments addressed
yilinxia aafbbd0
fixed panel version & pymc issue
yilinxia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import arviz as az | ||
import bokeh.io | ||
import panel as pn | ||
import param | ||
from IPython.display import display | ||
|
||
bokeh.io.reset_output() | ||
bokeh.io.output_notebook() | ||
|
||
pn.extension() | ||
|
||
|
||
class ModelVar(param.Parameterized): | ||
model = param.Selector("") | ||
data_variable = param.Selector("") | ||
coor_variable = param.Selector("") | ||
|
||
def __init__(self, idatas_cmp, **params) -> None: | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.idatas_cmp = idatas_cmp | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.default_model = list(self.idatas_cmp.keys())[0] | ||
self.param["model"].objects = list(self.idatas_cmp.keys()) | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.param["model"].default = self.default_model | ||
self.param["data_variable"].objects = list( | ||
self.idatas_cmp[self.default_model].posterior.data_vars.variables | ||
) | ||
super().__init__(**params) | ||
|
||
@param.depends("model", watch=True) | ||
def _update_data_variables(self): | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
data_variables = list( | ||
self.idatas_cmp[self.model].posterior.data_vars.variables | ||
) | ||
self.param["data_variable"].objects = data_variables | ||
if self.data_variable not in data_variables: | ||
self.data_variable = data_variables[0] | ||
|
||
@param.depends("data_variable", watch=True) | ||
def _update_coordinates(self): | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if ( | ||
self.idatas_cmp[self.model] | ||
.posterior.data_vars.variables[self.data_variable][0][0] | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.size | ||
> 1 | ||
): | ||
coor_variables = list( | ||
self.idatas_cmp[self.model].posterior.indexes["school"] | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
else: | ||
coor_variables = [""] | ||
self.param["coor_variable"].objects = coor_variables | ||
if self.coor_variable not in coor_variables: | ||
self.coor_variable = coor_variables[0] | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class ForestDashboard(ModelVar): | ||
def __init__(self, idatas_cmp) -> None: | ||
self.idatas_cmp = idatas_cmp | ||
super().__init__(self.idatas_cmp) | ||
|
||
def dashboard_forest(self): | ||
# define the widgets | ||
multi_select = pn.widgets.MultiSelect( | ||
name="ModelSelect", | ||
options=list(self.idatas_cmp.keys()), | ||
value=["mA"], | ||
) | ||
thre_slider = pn.widgets.FloatSlider( | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name="HDI Probability", | ||
start=0, | ||
end=1, | ||
step=0.05, | ||
value=0.7, | ||
width=200, | ||
) | ||
truncate_checkbox = pn.widgets.Checkbox(name="Ridgeplot Truncate") | ||
ridge_quant = pn.widgets.RangeSlider( | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name="Ridgeplot Quantiles", | ||
start=0, | ||
end=1, | ||
value=(0.25, 0.75), | ||
step=0.01, | ||
width=200, | ||
) | ||
op_slider = pn.widgets.FloatSlider( | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name="Ridgeplot Overlap", | ||
start=0, | ||
end=1, | ||
step=0.05, | ||
value=0.7, | ||
width=200, | ||
) | ||
|
||
rope_slider = pn.widgets.RangeSlider( | ||
name="Rope Range", | ||
start=-10, | ||
end=10, | ||
value=(2, 5), | ||
step=1, | ||
width=200, | ||
) | ||
|
||
# construct widget | ||
@pn.depends( | ||
multi_select.param.value, | ||
thre_slider.param.value, | ||
rope_slider.param.value, | ||
self.param.data_variable, | ||
self.param.coor_variable, | ||
) | ||
def get_forest_plot( | ||
multi_select, thre_slider, rope_slider, | ||
data_variable, coor_variable | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
# generate graph | ||
data = [] | ||
for model_ in multi_select: | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
data.append(self.idatas_cmp[model_]) | ||
# add rope | ||
rope = {} | ||
school = {} | ||
school["school"] = coor_variable | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
school["rope"] = rope_slider | ||
rope[data_variable] = [school] | ||
# print(rope) | ||
forest_plt = az.plot_forest( | ||
data, | ||
model_names=multi_select, | ||
rope=rope, | ||
kind="forestplot", | ||
hdi_prob=thre_slider, | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
backend="bokeh", | ||
figsize=(9, 9), | ||
show=False, | ||
combined=True, | ||
colors="cycle", | ||
) | ||
return forest_plt[0][0] | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@pn.depends( | ||
multi_select.param.value, | ||
thre_slider.param.value, | ||
truncate_checkbox.param.value, | ||
ridge_quant.param.value, | ||
op_slider.param.value, | ||
) | ||
def get_ridge_plot( | ||
multi_select, | ||
thre_slider, | ||
truncate_checkbox, | ||
ridge_quant, | ||
op_slider, | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
# calculate the ridgeplot_quantiles | ||
temp_quant = list(ridge_quant) | ||
quant_ls = temp_quant | ||
quant_ls.sort() | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
avg_quant = sum(temp_quant) / 2 | ||
if quant_ls[0] < 0.5 and quant_ls[1] > 0.5: | ||
quant_ls.append(0.5) | ||
quant_ls.sort() | ||
else: | ||
quant_ls.append(avg_quant) | ||
quant_ls.sort() | ||
|
||
# generate graph | ||
data = [] | ||
for model_ in multi_select: | ||
data.append(self.idatas_cmp[model_]) | ||
|
||
ridge_plt = az.plot_forest( | ||
data, | ||
model_names=multi_select, | ||
kind="ridgeplot", | ||
hdi_prob=thre_slider, | ||
ridgeplot_truncate=truncate_checkbox, | ||
ridgeplot_quantiles=quant_ls, | ||
ridgeplot_overlap=op_slider, | ||
backend="bokeh", | ||
figsize=(9, 9), | ||
show=False, | ||
combined=True, | ||
colors="white", | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
return ridge_plt[0][0] | ||
|
||
plot_result_1 = pn.Column( | ||
pn.WidgetBox( | ||
"add rope", | ||
pn.Row( | ||
self.param.model, | ||
self.param.data_variable, | ||
self.param.coor_variable, | ||
), | ||
rope_slider, | ||
), | ||
get_forest_plot, | ||
) | ||
plot_result_2 = pn.Column( | ||
pn.Row(truncate_checkbox), | ||
pn.Row(ridge_quant, op_slider), | ||
get_ridge_plot, | ||
) | ||
# show up | ||
display( | ||
yilinxia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pn.Column( | ||
pn.Row(multi_select), | ||
thre_slider, | ||
# pn.Row(self.param), | ||
pn.Tabs( | ||
("Forest_Plot", plot_result_1), | ||
( | ||
"Rdiget_Plot", | ||
plot_result_2, | ||
), | ||
), | ||
).servable(), | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.