diff --git a/pymc3/variational/svgd.py b/pymc3/variational/svgd.py index 8d4663bd29..bca03e4864 100644 --- a/pymc3/variational/svgd.py +++ b/pymc3/variational/svgd.py @@ -5,6 +5,7 @@ import numpy as np import theano +from theano.ifelse import ifelse import theano.tensor as tt from tqdm import tqdm from .updates import adagrad @@ -12,6 +13,7 @@ import pymc3 as pm from pymc3.model import modelcontext + def rbf_kernel(X): # TODO. rbf may not be a good choice for high dimension data XY = tt.dot(X, X.transpose()) @@ -22,11 +24,11 @@ def rbf_kernel(X): V = H.flatten() # median distance - h = tt.switch(tt.eq((V.shape[0] % 2), 0), - # if even vector - tt.mean(tt.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]), - # if odd vector - tt.sort(V)[V.shape[0] // 2]) + h = ifelse(tt.eq((V.shape[0] % 2), 0), + # if even vector + tt.mean(tt.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]), + # if odd vector + tt.sort(V)[V.shape[0] // 2]) h = tt.sqrt(0.5 * h / tt.log(X.shape[0].astype('float32') + 1.0)) @@ -35,7 +37,7 @@ def rbf_kernel(X): sumkxy = tt.sum(Kxy, axis=1).dimshuffle(0, 'x') dxkxy = tt.add(dxkxy, tt.mul(X, sumkxy)) / (h ** 2) - return (Kxy, dxkxy) + return Kxy, dxkxy def _make_vectorized_logp_grad(vars, model, X): @@ -70,7 +72,7 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01, random_seed=None, model=None): if random_seed is not None: - seed(random_seed) + np.random.seed(random_seed) model = modelcontext(model) if vars is None: @@ -102,18 +104,28 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01, else: progress = np.arange(n) - for ii in progress: - svgd_step(ii) + try: + for ii in progress: + svgd_step(ii) + except KeyboardInterrupt: + pass + finally: + if hasattr(progress, 'close'): + progress.close() theta_val = theta.get_value() # Build trace - strace = pm.backends.NDArray() - strace.setup(theta_val.shape[0], 1) - for p in theta_val: - strace.record(model.bijection.rmap(p)) - strace.close() + strace = pm.backends.NDArray() + try: + strace.setup(theta_val.shape[0], 1) + for p in theta_val: + strace.record(model.bijection.rmap(p)) + except KeyboardInterrupt: + pass + finally: + strace.close() trace = pm.backends.base.MultiTrace([strace])