Skip to content

Commit 055f666

Browse files
committed
DOC Add better docs and add advi_map.
1 parent 9a03460 commit 055f666

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

pymc3/sampling.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -386,18 +386,22 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
386386
387387
Parameteres
388388
-----------
389-
init : str {'advi', 'map', 'nuts'}
389+
draws : int
390+
Number of posterior samples to draw.
391+
init : str {'advi', 'advi_map', 'map', 'nuts'}
390392
Initialization method to use.
393+
* advi : Run ADVI to estimate posterior mean and diagonal covariance matrix.
394+
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
395+
* map : Use the MAP as starting point.
396+
* nuts : Run NUTS and estimate posterior mean and covariance matrix.
391397
n_init : int
392398
Number of iterations of initializer
393399
If 'advi', number of iterations, if 'metropolis', number of draws.
394400
sampler : str {'nuts', 'hmc', advi'}
395401
Sampler to use. Will be initialized using init algorithm.
396-
draws : int
397-
Number of posterior samples to draw.
398-
njobs : int
399-
Number of parallel jobs to start. If None, set to number of cpus
400-
in the system - 2.
402+
* nuts : Run NUTS sampler with the init covariance estimation as the scaling matrix.
403+
* hmc : Run HamiltonianMC sampler with the init covariance estimation as the scaling matrix.
404+
* advi : Sample from variational posterior, requires init='advi'.
401405
**kwargs : additional keyword argumemts
402406
Additional keyword argumemts are forwared to pymc3.sample()
403407
@@ -411,9 +415,12 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
411415

412416
if init == 'advi':
413417
v_params = pm.variational.advi(n=n_init)
414-
start = v_params.means
418+
start = pm.variational.sample_vp(v_params, 1)[0]
419+
cov = np.power(model.dict_to_array(v_params.stds), 2)
420+
elif init == 'advi_map':
421+
start = pm.find_MAP()
422+
v_params = pm.variational.advi(n=n_init, start=start)
415423
cov = np.power(model.dict_to_array(v_params.stds), 2)
416-
417424
elif init == 'map':
418425
start = pm.find_MAP()
419426
cov = pm.find_hessian(point=start)

pymc3/tests/test_sampling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_sample(self):
6363

6464
def test_sample_init(self):
6565
with self.model:
66-
for init in ('advi', 'map', 'nuts'):
66+
for init in ('advi', 'advi_map', 'map', 'nuts'):
6767
for sampler in ('nuts', 'hmc', 'advi'):
6868
if (sampler == 'advi') and (init != 'advi'):
6969
self.assertRaises(ValueError, pm.sample_init,

0 commit comments

Comments
 (0)