11
11
__all__ = ['NUTS' ]
12
12
13
13
14
- def bern (p ):
15
- return nr .uniform () < p
14
+ def logbern (log_p ):
15
+ if np .isnan (log_p ):
16
+ raise FloatingPointError ("log_p can't be nan." )
17
+ return np .log (nr .uniform ()) < log_p
16
18
17
19
18
20
class NUTS (BaseHMC ):
@@ -77,7 +79,8 @@ class NUTS(BaseHMC):
77
79
}]
78
80
79
81
def __init__ (self , vars = None , Emax = 1000 , target_accept = 0.8 ,
80
- gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True , ** kwargs ):
82
+ gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True ,
83
+ max_treedepth = 10 , ** kwargs ):
81
84
"""
82
85
Parameters
83
86
----------
@@ -122,11 +125,13 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
122
125
self .log_step_size_bar = 0
123
126
self .m = 1
124
127
self .adapt_step_size = adapt_step_size
128
+ self .max_treedepth = max_treedepth
125
129
126
130
self .tune = True
127
131
128
132
def astep (self , q0 ):
129
133
p0 = self .potential .random ()
134
+ v0 = self .compute_velocity (p0 )
130
135
start_energy = self .compute_energy (q0 , p0 )
131
136
132
137
if not self .adapt_step_size :
@@ -136,12 +141,11 @@ def astep(self, q0):
136
141
else :
137
142
step_size = np .exp (self .log_step_size_bar )
138
143
139
- u = nr .uniform ()
140
- start = Edge (q0 , p0 , self .dlogp (q0 ), start_energy )
141
- tree = Tree (self .leapfrog , start , u , step_size , self .Emax )
144
+ start = Edge (q0 , p0 , v0 , self .dlogp (q0 ), start_energy )
145
+ tree = Tree (len (p0 ), self .leapfrog , start , step_size , self .Emax )
142
146
143
- while True :
144
- direction = bern ( 0.5 ) * 2 - 1
147
+ for _ in range ( self . max_treedepth ) :
148
+ direction = logbern ( np . log ( 0.5 ) ) * 2 - 1
145
149
diverging , turning = tree .extend (direction )
146
150
q = tree .proposal .q
147
151
@@ -178,17 +182,17 @@ def competence(var):
178
182
179
183
180
184
# A node in the NUTS tree that is at the far right or left of the tree
181
- Edge = namedtuple ("Edge" , 'q, p, q_grad, energy' )
185
+ Edge = namedtuple ("Edge" , 'q, p, v, q_grad, energy' )
182
186
183
187
# A proposal for the next position
184
188
Proposal = namedtuple ("Proposal" , "q, energy, p_accept" )
185
189
186
- # A subtree of the binary tree build by nuts.
187
- Subtree = namedtuple ("Subtree" , "left, right, proposal, depth, size , accept_sum, n_proposals" )
190
+ # A subtree of the binary tree built by nuts.
191
+ Subtree = namedtuple ("Subtree" , "left, right, p_sum, proposal, log_size , accept_sum, n_proposals" )
188
192
189
193
190
194
class Tree (object ):
191
- def __init__ (self , leapfrog , start , u , step_size , Emax ):
195
+ def __init__ (self , ndim , leapfrog , start , step_size , Emax ):
192
196
"""Binary tree from the NUTS algorithm.
193
197
194
198
Parameters
@@ -197,28 +201,26 @@ def __init__(self, leapfrog, start, u, step_size, Emax):
197
201
A function that performs a single leapfrog step.
198
202
start : Edge
199
203
The starting point of the trajectory.
200
- u : float in [0, 1]
201
- Random slice sampling variable.
202
204
step_size : float
203
205
The step size to use in this tree
204
206
Emax : float
205
207
The maximum energy change to accept before aborting the
206
208
transition as diverging.
207
209
"""
210
+ self .ndim = ndim
208
211
self .leapfrog = leapfrog
209
212
self .start = start
210
- self .log_u = np .log (u )
211
213
self .step_size = step_size
212
214
self .Emax = Emax
213
215
self .start_energy = np .array (start .energy )
214
216
215
217
self .left = self .right = start
216
218
self .proposal = Proposal (start .q , start .energy , 1.0 )
217
219
self .depth = 0
218
- self .size = 1
219
- # TODO Why not a global accept sum and n_proposals?
220
- # self.accept_sum = 0
221
- # self.n_proposals = 0
220
+ self .log_size = 0
221
+ self . accept_sum = 0
222
+ self .n_proposals = 0
223
+ self .p_sum = start . p . copy ()
222
224
self .max_energy_change = 0
223
225
224
226
def extend (self , direction ):
@@ -239,40 +241,46 @@ def extend(self, direction):
239
241
self .right = tree .right
240
242
else :
241
243
tree , diverging , turning = self ._build_subtree (
242
- self .left , self .depth , floatX (np .asarray (- self .step_size )))
244
+ self .left , self .depth , floatX (np .asarray (- self .step_size )))
243
245
self .left = tree .right
244
246
245
- ok = not (diverging or turning )
246
- if ok and bern (min (1 , tree .size / self .size )):
247
+ self .depth += 1
248
+ self .accept_sum += tree .accept_sum
249
+ self .n_proposals += tree .n_proposals
250
+
251
+ if diverging or turning :
252
+ return diverging , turning
253
+
254
+ size1 , size2 = self .log_size , tree .log_size
255
+ if logbern (size2 - size1 ):
247
256
self .proposal = tree .proposal
248
257
249
- self .depth += 1
250
- self .size += tree .size
251
- # TODO why not +=
252
- #self.accept_sum += tree.accept_sum
253
- self .accept_sum = tree .accept_sum
254
- #self.n_proposals += tree.n_proposals
255
- self .n_proposals = tree .n_proposals
258
+ self .log_size = np .logaddexp (self .log_size , tree .log_size )
259
+ self .p_sum [:] += tree .p_sum
256
260
257
261
left , right = self .left , self .right
258
- span = right .q - left .q
259
- turning = turning or (span .dot (left .p ) < 0 ) or (span .dot (right .p ) < 0 )
262
+ p_sum = self .p_sum
263
+ turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
264
+
260
265
return diverging , turning
261
266
262
267
def _build_subtree (self , left , depth , epsilon ):
263
268
if depth == 0 :
264
269
right = self .leapfrog (left .q , left .p , left .q_grad , epsilon )
265
270
right = Edge (* right )
266
271
energy_change = right .energy - self .start_energy
272
+ if np .isnan (energy_change ):
273
+ energy_change = np .inf
274
+
267
275
if np .abs (energy_change ) > np .abs (self .max_energy_change ):
268
276
self .max_energy_change = energy_change
269
277
p_accept = min (1 , np .exp (- energy_change ))
270
278
271
- size = int ( self . log_u + energy_change <= 0 )
272
- diverging = not ( self . log_u + energy_change < self .Emax )
279
+ log_size = - energy_change
280
+ diverging = energy_change > self .Emax
273
281
274
282
proposal = Proposal (right .q , right .energy , p_accept )
275
- tree = Subtree (right , right , proposal , 1 , size , p_accept , 1 )
283
+ tree = Subtree (right , right , right . p , proposal , log_size , p_accept , 1 )
276
284
return tree , diverging , False
277
285
278
286
tree1 , diverging , turning = self ._build_subtree (left , depth - 1 , epsilon )
@@ -281,20 +289,26 @@ def _build_subtree(self, left, depth, epsilon):
281
289
282
290
tree2 , diverging , turning = self ._build_subtree (tree1 .right , depth - 1 , epsilon )
283
291
284
- size = tree1 .size + tree2 .size
285
- accept_sum = tree1 .accept_sum + tree2 .accept_sum
286
- n_proposals = tree1 .n_proposals + tree2 .n_proposals
287
-
288
292
left , right = tree1 .left , tree2 .right
289
- span = np .sign (epsilon ) * (right .q - left .q )
290
- turning = turning or (span .dot (left .p ) < 0 ) or (span .dot (right .p ) < 0 )
291
293
292
- if bern (tree2 .size * 1. / max (size , 1 )):
293
- proposal = tree2 .proposal
294
+ if not (diverging or turning ):
295
+ p_sum = tree1 .p_sum + tree2 .p_sum
296
+ turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
297
+
298
+ log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
299
+ if logbern (tree2 .log_size - log_size ):
300
+ proposal = tree2 .proposal
301
+ else :
302
+ proposal = tree1 .proposal
294
303
else :
304
+ p_sum = tree1 .p_sum
305
+ log_size = tree1 .log_size
295
306
proposal = tree1 .proposal
296
307
297
- tree = Subtree (left , right , proposal , depth , size , accept_sum , n_proposals )
308
+ accept_sum = tree1 .accept_sum + tree2 .accept_sum
309
+ n_proposals = tree1 .n_proposals + tree2 .n_proposals
310
+
311
+ tree = Subtree (left , right , p_sum , proposal , log_size , accept_sum , n_proposals )
298
312
return tree , diverging , turning
299
313
300
314
def stats (self ):
0 commit comments