@@ -79,7 +79,8 @@ class NUTS(BaseHMC):
79
79
}]
80
80
81
81
def __init__ (self , vars = None , Emax = 1000 , target_accept = 0.8 ,
82
- gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True , ** kwargs ):
82
+ gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True ,
83
+ max_treedepth = 10 , ** kwargs ):
83
84
"""
84
85
Parameters
85
86
----------
@@ -124,11 +125,13 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124
125
self .log_step_size_bar = 0
125
126
self .m = 1
126
127
self .adapt_step_size = adapt_step_size
128
+ self .max_treedepth = max_treedepth
127
129
128
130
self .tune = True
129
131
130
132
def astep (self , q0 ):
131
133
p0 = self .potential .random ()
134
+ v0 = self .compute_velocity (p0 )
132
135
start_energy = self .compute_energy (q0 , p0 )
133
136
134
137
if not self .adapt_step_size :
@@ -138,10 +141,10 @@ def astep(self, q0):
138
141
else :
139
142
step_size = np .exp (self .log_step_size_bar )
140
143
141
- start = Edge (q0 , p0 , self .dlogp (q0 ), start_energy )
142
- tree = Tree (self .leapfrog , start , step_size , self .Emax )
144
+ start = Edge (q0 , p0 , v0 , self .dlogp (q0 ), start_energy )
145
+ tree = Tree (len ( p0 ), self .leapfrog , start , step_size , self .Emax )
143
146
144
- while True :
147
+ for _ in range ( self . max_treedepth ) :
145
148
direction = logbern (np .log (0.5 )) * 2 - 1
146
149
diverging , turning = tree .extend (direction )
147
150
q = tree .proposal .q
@@ -179,17 +182,17 @@ def competence(var):
179
182
180
183
181
184
# A node in the NUTS tree that is at the far right or left of the tree
182
- Edge = namedtuple ("Edge" , 'q, p, q_grad, energy' )
185
+ Edge = namedtuple ("Edge" , 'q, p, v, q_grad, energy' )
183
186
184
187
# A proposal for the next position
185
188
Proposal = namedtuple ("Proposal" , "q, energy, p_accept" )
186
189
187
190
# A subtree of the binary tree build by nuts.
188
- Subtree = namedtuple ("Subtree" , "left, right, proposal, depth , log_size, accept_sum, n_proposals" )
191
+ Subtree = namedtuple ("Subtree" , "left, right, p_sum, proposal , log_size, accept_sum, n_proposals" )
189
192
190
193
191
194
class Tree (object ):
192
- def __init__ (self , leapfrog , start , step_size , Emax ):
195
+ def __init__ (self , ndim , leapfrog , start , step_size , Emax ):
193
196
"""Binary tree from the NUTS algorithm.
194
197
195
198
Parameters
@@ -204,6 +207,7 @@ def __init__(self, leapfrog, start, step_size, Emax):
204
207
The maximum energy change to accept before aborting the
205
208
transition as diverging.
206
209
"""
210
+ self .ndim = ndim
207
211
self .leapfrog = leapfrog
208
212
self .start = start
209
213
self .step_size = step_size
@@ -214,9 +218,9 @@ def __init__(self, leapfrog, start, step_size, Emax):
214
218
self .proposal = Proposal (start .q , start .energy , 1.0 )
215
219
self .depth = 0
216
220
self .log_size = 0
217
- # TODO Why not a global accept sum and n_proposals?
218
- # self.accept_sum = 0
219
- # self.n_proposals = 0
221
+ self . accept_sum = 0
222
+ self .n_proposals = 0
223
+ self .p_sum = start . p . copy ()
220
224
self .max_energy_change = 0
221
225
222
226
def extend (self , direction ):
@@ -237,63 +241,75 @@ def extend(self, direction):
237
241
self .right = tree .right
238
242
else :
239
243
tree , diverging , turning = self ._build_subtree (
240
- self .left , self .depth , floatX (np .asarray (- self .step_size )))
244
+ self .left , self .depth , floatX (np .asarray (- self .step_size )))
241
245
self .left = tree .right
242
246
243
- ok = not (diverging or turning )
244
- if ok and logbern (tree .log_size - self .log_size ):
247
+ self .depth += 1
248
+ self .accept_sum += tree .accept_sum
249
+ self .n_proposals += tree .n_proposals
250
+
251
+ if diverging or turning :
252
+ return diverging , turning
253
+
254
+ size1 , size2 = self .log_size , tree .log_size
255
+ if logbern (size2 - size1 ):
245
256
self .proposal = tree .proposal
246
257
247
- self .depth += 1
248
258
self .log_size = np .logaddexp (self .log_size , tree .log_size )
249
- # TODO why not +=
250
- #self.accept_sum += tree.accept_sum
251
- self .accept_sum = tree .accept_sum
252
- #self.n_proposals += tree.n_proposals
253
- self .n_proposals = tree .n_proposals
259
+ self .p_sum [:] += tree .p_sum
254
260
255
261
left , right = self .left , self .right
256
- span = right .q - left .q
257
- turning = turning or (span .dot (left .p ) < 0 ) or (span .dot (right .p ) < 0 )
262
+ p_sum = self .p_sum
263
+ turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
264
+
258
265
return diverging , turning
259
266
260
267
def _build_subtree (self , left , depth , epsilon ):
261
268
if depth == 0 :
262
269
right = self .leapfrog (left .q , left .p , left .q_grad , epsilon )
263
270
right = Edge (* right )
264
271
energy_change = right .energy - self .start_energy
272
+ if np .isnan (energy_change ):
273
+ energy_change = np .inf
274
+
265
275
if np .abs (energy_change ) > np .abs (self .max_energy_change ):
266
276
self .max_energy_change = energy_change
267
277
p_accept = min (1 , np .exp (- energy_change ))
268
278
269
279
log_size = - energy_change
270
- diverging = not np .isfinite (energy_change )
271
- diverging = diverging or (np .abs (energy_change ) > self .Emax )
280
+ diverging = energy_change > self .Emax
272
281
273
282
proposal = Proposal (right .q , right .energy , p_accept )
274
- tree = Subtree (right , right , proposal , 1 , log_size , p_accept , 1 )
283
+ tree = Subtree (right , right , right . p , proposal , log_size , p_accept , 1 )
275
284
return tree , diverging , False
276
285
277
286
tree1 , diverging , turning = self ._build_subtree (left , depth - 1 , epsilon )
278
287
if diverging or turning :
279
288
return tree1 , diverging , turning
280
289
281
290
tree2 , diverging , turning = self ._build_subtree (tree1 .right , depth - 1 , epsilon )
291
+ ok = not (diverging or turning )
282
292
283
- log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
284
293
accept_sum = tree1 .accept_sum + tree2 .accept_sum
285
294
n_proposals = tree1 .n_proposals + tree2 .n_proposals
286
295
287
296
left , right = tree1 .left , tree2 .right
288
- span = np .sign (epsilon ) * (right .q - left .q )
289
- turning = turning or (span .dot (left .p ) < 0 ) or (span .dot (right .p ) < 0 )
290
297
291
- if np .isfinite (tree2 .log_size ) and logbern (tree2 .log_size - log_size ):
292
- proposal = tree2 .proposal
298
+ if ok :
299
+ p_sum = tree1 .p_sum + tree2 .p_sum
300
+ turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
301
+
302
+ log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
303
+ if logbern (tree2 .log_size - log_size ):
304
+ proposal = tree2 .proposal
305
+ else :
306
+ proposal = tree1 .proposal
293
307
else :
308
+ p_sum = tree1 .p_sum
309
+ log_size = tree1 .log_size
294
310
proposal = tree1 .proposal
295
311
296
- tree = Subtree (left , right , proposal , depth , log_size , accept_sum , n_proposals )
312
+ tree = Subtree (left , right , p_sum , proposal , log_size , accept_sum , n_proposals )
297
313
return tree , diverging , turning
298
314
299
315
def stats (self ):
0 commit comments