@@ -386,18 +386,22 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
386
386
387
387
Parameteres
388
388
-----------
389
- init : str {'advi', 'map', 'nuts'}
389
+ draws : int
390
+ Number of posterior samples to draw.
391
+ init : str {'advi', 'advi_map', 'map', 'nuts'}
390
392
Initialization method to use.
393
+ * advi : Run ADVI to estimate posterior mean and diagonal covariance matrix.
394
+ * advi_map: Initialize ADVI with MAP and use MAP as starting point.
395
+ * map : Use the MAP as starting point.
396
+ * nuts : Run NUTS and estimate posterior mean and covariance matrix.
391
397
n_init : int
392
398
Number of iterations of initializer
393
399
If 'advi', number of iterations, if 'metropolis', number of draws.
394
400
sampler : str {'nuts', 'hmc', advi'}
395
401
Sampler to use. Will be initialized using init algorithm.
396
- draws : int
397
- Number of posterior samples to draw.
398
- njobs : int
399
- Number of parallel jobs to start. If None, set to number of cpus
400
- in the system - 2.
402
+ * nuts : Run NUTS sampler with the init covariance estimation as the scaling matrix.
403
+ * hmc : Run HamiltonianMC sampler with the init covariance estimation as the scaling matrix.
404
+ * advi : Sample from variational posterior, requires init='advi'.
401
405
**kwargs : additional keyword argumemts
402
406
Additional keyword argumemts are forwared to pymc3.sample()
403
407
@@ -411,9 +415,12 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
411
415
412
416
if init == 'advi' :
413
417
v_params = pm .variational .advi (n = n_init )
414
- start = v_params .means
418
+ start = pm .variational .sample_vp (v_params , 1 )[0 ]
419
+ cov = np .power (model .dict_to_array (v_params .stds ), 2 )
420
+ elif init == 'advi_map' :
421
+ start = pm .find_MAP ()
422
+ v_params = pm .variational .advi (n = n_init , start = start )
415
423
cov = np .power (model .dict_to_array (v_params .stds ), 2 )
416
-
417
424
elif init == 'map' :
418
425
start = pm .find_MAP ()
419
426
cov = pm .find_hessian (point = start )
0 commit comments