Skip to content

Commit aa2da14

Browse files
authored
SVGD problems (#1916)
* fix some svgd problems * switch -> ifelse * except in record
1 parent fed94d3 commit aa2da14

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

pymc3/variational/svgd.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
import numpy as np
77
import theano
8+
from theano.ifelse import ifelse
89
import theano.tensor as tt
910
from tqdm import tqdm
1011
from .updates import adagrad
1112

1213
import pymc3 as pm
1314
from pymc3.model import modelcontext
1415

16+
1517
def rbf_kernel(X):
1618
# TODO. rbf may not be a good choice for high dimension data
1719
XY = tt.dot(X, X.transpose())
@@ -22,11 +24,11 @@ def rbf_kernel(X):
2224
V = H.flatten()
2325

2426
# median distance
25-
h = tt.switch(tt.eq((V.shape[0] % 2), 0),
26-
# if even vector
27-
tt.mean(tt.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]),
28-
# if odd vector
29-
tt.sort(V)[V.shape[0] // 2])
27+
h = ifelse(tt.eq((V.shape[0] % 2), 0),
28+
# if even vector
29+
tt.mean(tt.sort(V)[ ((V.shape[0] // 2) - 1) : ((V.shape[0] // 2) + 1) ]),
30+
# if odd vector
31+
tt.sort(V)[V.shape[0] // 2])
3032

3133
h = tt.sqrt(0.5 * h / tt.log(X.shape[0].astype('float32') + 1.0))
3234

@@ -35,7 +37,7 @@ def rbf_kernel(X):
3537
sumkxy = tt.sum(Kxy, axis=1).dimshuffle(0, 'x')
3638
dxkxy = tt.add(dxkxy, tt.mul(X, sumkxy)) / (h ** 2)
3739

38-
return (Kxy, dxkxy)
40+
return Kxy, dxkxy
3941

4042

4143
def _make_vectorized_logp_grad(vars, model, X):
@@ -70,7 +72,7 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01,
7072
random_seed=None, model=None):
7173

7274
if random_seed is not None:
73-
seed(random_seed)
75+
np.random.seed(random_seed)
7476

7577
model = modelcontext(model)
7678
if vars is None:
@@ -102,18 +104,28 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01,
102104
else:
103105
progress = np.arange(n)
104106

105-
for ii in progress:
106-
svgd_step(ii)
107+
try:
108+
for ii in progress:
109+
svgd_step(ii)
110+
except KeyboardInterrupt:
111+
pass
112+
finally:
113+
if hasattr(progress, 'close'):
114+
progress.close()
107115

108116
theta_val = theta.get_value()
109117

110118
# Build trace
111-
strace = pm.backends.NDArray()
112-
strace.setup(theta_val.shape[0], 1)
113119

114-
for p in theta_val:
115-
strace.record(model.bijection.rmap(p))
116-
strace.close()
120+
strace = pm.backends.NDArray()
121+
try:
122+
strace.setup(theta_val.shape[0], 1)
123+
for p in theta_val:
124+
strace.record(model.bijection.rmap(p))
125+
except KeyboardInterrupt:
126+
pass
127+
finally:
128+
strace.close()
117129

118130
trace = pm.backends.base.MultiTrace([strace])
119131

0 commit comments

Comments
 (0)