5
5
6
6
import numpy as np
7
7
import theano
8
+ from theano .ifelse import ifelse
8
9
import theano .tensor as tt
9
10
from tqdm import tqdm
10
11
from .updates import adagrad
11
12
12
13
import pymc3 as pm
13
14
from pymc3 .model import modelcontext
14
15
16
+
15
17
def rbf_kernel (X ):
16
18
# TODO. rbf may not be a good choice for high dimension data
17
19
XY = tt .dot (X , X .transpose ())
@@ -22,11 +24,11 @@ def rbf_kernel(X):
22
24
V = H .flatten ()
23
25
24
26
# median distance
25
- h = tt . switch (tt .eq ((V .shape [0 ] % 2 ), 0 ),
26
- # if even vector
27
- tt .mean (tt .sort (V )[ ((V .shape [0 ] // 2 ) - 1 ) : ((V .shape [0 ] // 2 ) + 1 ) ]),
28
- # if odd vector
29
- tt .sort (V )[V .shape [0 ] // 2 ])
27
+ h = ifelse (tt .eq ((V .shape [0 ] % 2 ), 0 ),
28
+ # if even vector
29
+ tt .mean (tt .sort (V )[ ((V .shape [0 ] // 2 ) - 1 ) : ((V .shape [0 ] // 2 ) + 1 ) ]),
30
+ # if odd vector
31
+ tt .sort (V )[V .shape [0 ] // 2 ])
30
32
31
33
h = tt .sqrt (0.5 * h / tt .log (X .shape [0 ].astype ('float32' ) + 1.0 ))
32
34
@@ -35,7 +37,7 @@ def rbf_kernel(X):
35
37
sumkxy = tt .sum (Kxy , axis = 1 ).dimshuffle (0 , 'x' )
36
38
dxkxy = tt .add (dxkxy , tt .mul (X , sumkxy )) / (h ** 2 )
37
39
38
- return ( Kxy , dxkxy )
40
+ return Kxy , dxkxy
39
41
40
42
41
43
def _make_vectorized_logp_grad (vars , model , X ):
@@ -70,7 +72,7 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01,
70
72
random_seed = None , model = None ):
71
73
72
74
if random_seed is not None :
73
- seed (random_seed )
75
+ np . random . seed (random_seed )
74
76
75
77
model = modelcontext (model )
76
78
if vars is None :
@@ -102,18 +104,28 @@ def svgd(vars=None, n=5000, n_particles=100, jitter=.01,
102
104
else :
103
105
progress = np .arange (n )
104
106
105
- for ii in progress :
106
- svgd_step (ii )
107
+ try :
108
+ for ii in progress :
109
+ svgd_step (ii )
110
+ except KeyboardInterrupt :
111
+ pass
112
+ finally :
113
+ if hasattr (progress , 'close' ):
114
+ progress .close ()
107
115
108
116
theta_val = theta .get_value ()
109
117
110
118
# Build trace
111
- strace = pm .backends .NDArray ()
112
- strace .setup (theta_val .shape [0 ], 1 )
113
119
114
- for p in theta_val :
115
- strace .record (model .bijection .rmap (p ))
116
- strace .close ()
120
+ strace = pm .backends .NDArray ()
121
+ try :
122
+ strace .setup (theta_val .shape [0 ], 1 )
123
+ for p in theta_val :
124
+ strace .record (model .bijection .rmap (p ))
125
+ except KeyboardInterrupt :
126
+ pass
127
+ finally :
128
+ strace .close ()
117
129
118
130
trace = pm .backends .base .MultiTrace ([strace ])
119
131
0 commit comments