Skip to content

Commit dbf0654

Browse files
authored
ENH Add ADVI initializing for continuous models sampled using NUTS.
2 parents ad7dc8a + 59f500c commit dbf0654

15 files changed

+572
-529
lines changed

docs/source/notebooks/BEST.ipynb

Lines changed: 25 additions & 58 deletions
Large diffs are not rendered by default.

docs/source/notebooks/GLM-hierarchical.ipynb

Lines changed: 44 additions & 55 deletions
Large diffs are not rendered by default.

docs/source/notebooks/LKJ.ipynb

Lines changed: 36 additions & 64 deletions
Large diffs are not rendered by default.

docs/source/notebooks/NUTS_scaling_using_ADVI.ipynb

Lines changed: 173 additions & 81 deletions
Large diffs are not rendered by default.

docs/source/notebooks/cox_model.ipynb

Lines changed: 46 additions & 45 deletions
Large diffs are not rendered by default.

docs/source/notebooks/marginalized_gaussian_mixture_model.ipynb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,9 @@
292292
}
293293
],
294294
"metadata": {
295+
"anaconda-cloud": {},
295296
"kernelspec": {
296-
"display_name": "Python 3",
297+
"display_name": "Python [default]",
297298
"language": "python",
298299
"name": "python3"
299300
},
@@ -307,11 +308,7 @@
307308
"name": "python",
308309
"nbconvert_exporter": "python",
309310
"pygments_lexer": "ipython3",
310-
"version": "3.5.1"
311-
},
312-
"widgets": {
313-
"state": {},
314-
"version": "1.1.2"
311+
"version": "3.5.2"
315312
}
316313
},
317314
"nbformat": 4,

docs/source/notebooks/pmf-pymc.ipynb

Lines changed: 47 additions & 48 deletions
Large diffs are not rendered by default.

docs/source/notebooks/posterior_predictive.ipynb

Lines changed: 58 additions & 55 deletions
Large diffs are not rendered by default.

docs/source/notebooks/stochastic_volatility.ipynb

Lines changed: 31 additions & 107 deletions
Large diffs are not rendered by default.

pymc3/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,16 @@ def as_iterargs(data):
703703
else:
704704
return [data]
705705

706+
707+
def all_continuous(vars):
708+
"""Check that vars not include discrete variables, excepting ObservedRVs.
709+
"""
710+
vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
711+
if any([var.dtype in pm.discrete_types for var in vars_]):
712+
return False
713+
else:
714+
return True
715+
706716
# theano stuff
707717
theano.config.warn.sum_div_dimshuffle_bug = False
708718
theano.config.compute_test_value = 'raise'

0 commit comments

Comments
 (0)