Skip to content

Commit 209d35f

Browse files
committed
Use sum of momentum for nuts termination
1 parent 14bef5c commit 209d35f

File tree

8 files changed

+130
-76
lines changed

8 files changed

+130
-76
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
5959
if theano_kwargs is None:
6060
theano_kwargs = {}
6161

62-
self.H, self.compute_energy, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
62+
self.H, self.compute_energy, self.compute_velocity, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
6363
vars, shared, model.logpt, self.potential, use_single_leapfrog, integrator, **theano_kwargs)
6464

6565
super(BaseHMC, self).__init__(vars, shared, blocked=blocked)

pymc3/step_methods/hmc/nuts.py

Lines changed: 46 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,8 @@ class NUTS(BaseHMC):
7979
}]
8080

8181
def __init__(self, vars=None, Emax=1000, target_accept=0.8,
82-
gamma=0.05, k=0.75, t0=10, adapt_step_size=True, **kwargs):
82+
gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
83+
max_treedepth=10, **kwargs):
8384
"""
8485
Parameters
8586
----------
@@ -124,11 +125,13 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124125
self.log_step_size_bar = 0
125126
self.m = 1
126127
self.adapt_step_size = adapt_step_size
128+
self.max_treedepth = max_treedepth
127129

128130
self.tune = True
129131

130132
def astep(self, q0):
131133
p0 = self.potential.random()
134+
v0 = self.compute_velocity(p0)
132135
start_energy = self.compute_energy(q0, p0)
133136

134137
if not self.adapt_step_size:
@@ -138,10 +141,10 @@ def astep(self, q0):
138141
else:
139142
step_size = np.exp(self.log_step_size_bar)
140143

141-
start = Edge(q0, p0, self.dlogp(q0), start_energy)
142-
tree = Tree(self.leapfrog, start, step_size, self.Emax)
144+
start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
145+
tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
143146

144-
while True:
147+
for _ in range(self.max_treedepth):
145148
direction = logbern(np.log(0.5)) * 2 - 1
146149
diverging, turning = tree.extend(direction)
147150
q = tree.proposal.q
@@ -179,17 +182,17 @@ def competence(var):
179182

180183

181184
# A node in the NUTS tree that is at the far right or left of the tree
182-
Edge = namedtuple("Edge", 'q, p, q_grad, energy')
185+
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
183186

184187
# A proposal for the next position
185188
Proposal = namedtuple("Proposal", "q, energy, p_accept")
186189

187190
# A subtree of the binary tree built by NUTS.
188-
Subtree = namedtuple("Subtree", "left, right, proposal, depth, log_size, accept_sum, n_proposals")
191+
Subtree = namedtuple("Subtree", "left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
189192

190193

191194
class Tree(object):
192-
def __init__(self, leapfrog, start, step_size, Emax):
195+
def __init__(self, ndim, leapfrog, start, step_size, Emax):
193196
"""Binary tree from the NUTS algorithm.
194197
195198
Parameters
@@ -204,6 +207,7 @@ def __init__(self, leapfrog, start, step_size, Emax):
204207
The maximum energy change to accept before aborting the
205208
transition as diverging.
206209
"""
210+
self.ndim = ndim
207211
self.leapfrog = leapfrog
208212
self.start = start
209213
self.step_size = step_size
@@ -214,9 +218,9 @@ def __init__(self, leapfrog, start, step_size, Emax):
214218
self.proposal = Proposal(start.q, start.energy, 1.0)
215219
self.depth = 0
216220
self.log_size = 0
217-
# TODO Why not a global accept sum and n_proposals?
218-
#self.accept_sum = 0
219-
#self.n_proposals = 0
221+
self.accept_sum = 0
222+
self.n_proposals = 0
223+
self.p_sum = start.p.copy()
220224
self.max_energy_change = 0
221225

222226
def extend(self, direction):
@@ -237,63 +241,75 @@ def extend(self, direction):
237241
self.right = tree.right
238242
else:
239243
tree, diverging, turning = self._build_subtree(
240-
self.left, self.depth, floatX(np.asarray(- self.step_size)))
244+
self.left, self.depth, floatX(np.asarray(-self.step_size)))
241245
self.left = tree.right
242246

243-
ok = not (diverging or turning)
244-
if ok and logbern(tree.log_size - self.log_size):
247+
self.depth += 1
248+
self.accept_sum += tree.accept_sum
249+
self.n_proposals += tree.n_proposals
250+
251+
if diverging or turning:
252+
return diverging, turning
253+
254+
size1, size2 = self.log_size, tree.log_size
255+
if logbern(size2 - size1):
245256
self.proposal = tree.proposal
246257

247-
self.depth += 1
248258
self.log_size = np.logaddexp(self.log_size, tree.log_size)
249-
# TODO why not +=
250-
#self.accept_sum += tree.accept_sum
251-
self.accept_sum = tree.accept_sum
252-
#self.n_proposals += tree.n_proposals
253-
self.n_proposals = tree.n_proposals
259+
self.p_sum[:] += tree.p_sum
254260

255261
left, right = self.left, self.right
256-
span = right.q - left.q
257-
turning = turning or (span.dot(left.p) < 0) or (span.dot(right.p) < 0)
262+
p_sum = self.p_sum
263+
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
264+
258265
return diverging, turning
259266

260267
def _build_subtree(self, left, depth, epsilon):
261268
if depth == 0:
262269
right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
263270
right = Edge(*right)
264271
energy_change = right.energy - self.start_energy
272+
if np.isnan(energy_change):
273+
energy_change = np.inf
274+
265275
if np.abs(energy_change) > np.abs(self.max_energy_change):
266276
self.max_energy_change = energy_change
267277
p_accept = min(1, np.exp(-energy_change))
268278

269279
log_size = -energy_change
270-
diverging = not np.isfinite(energy_change)
271-
diverging = diverging or (np.abs(energy_change) > self.Emax)
280+
diverging = energy_change > self.Emax
272281

273282
proposal = Proposal(right.q, right.energy, p_accept)
274-
tree = Subtree(right, right, proposal, 1, log_size, p_accept, 1)
283+
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
275284
return tree, diverging, False
276285

277286
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
278287
if diverging or turning:
279288
return tree1, diverging, turning
280289

281290
tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
291+
ok = not (diverging or turning)
282292

283-
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
284293
accept_sum = tree1.accept_sum + tree2.accept_sum
285294
n_proposals = tree1.n_proposals + tree2.n_proposals
286295

287296
left, right = tree1.left, tree2.right
288-
span = np.sign(epsilon) * (right.q - left.q)
289-
turning = turning or (span.dot(left.p) < 0) or (span.dot(right.p) < 0)
290297

291-
if np.isfinite(tree2.log_size) and logbern(tree2.log_size - log_size):
292-
proposal = tree2.proposal
298+
if ok:
299+
p_sum = tree1.p_sum + tree2.p_sum
300+
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
301+
302+
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
303+
if logbern(tree2.log_size - log_size):
304+
proposal = tree2.proposal
305+
else:
306+
proposal = tree1.proposal
293307
else:
308+
p_sum = tree1.p_sum
309+
log_size = tree1.log_size
294310
proposal = tree1.proposal
295311

296-
tree = Subtree(left, right, proposal, depth, log_size, accept_sum, n_proposals)
312+
tree = Subtree(left, right, p_sum, proposal, log_size, accept_sum, n_proposals)
297313
return tree, diverging, turning
298314

299315
def stats(self):

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from numpy import dot
22
from numpy.random import normal
33
import scipy.linalg
4+
import theano.tensor as tt
45
from theano.tensor import slinalg
56
from scipy.sparse import issparse
67

@@ -123,7 +124,7 @@ def __init__(self, A):
123124
self.L = scipy.linalg.cholesky(A, lower=True)
124125

125126
def velocity(self, x):
126-
return x.T.dot(self.A.T)
127+
return tt.dot(self.A, x)
127128

128129
def random(self):
129130
n = floatX(normal(size=self.L.shape[0]))

pymc3/step_methods/hmc/trajectory.py

Lines changed: 47 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,13 @@ def _theano_energy_function(H, q, **theano_kwargs):
5757
return energy_function, p
5858

5959

60+
def _theano_velocity_function(H, p, **theano_kwargs):
61+
v = H.pot.velocity(p)
62+
velocity_function = theano.function(inputs=[p], outputs=v, **theano_kwargs)
63+
velocity_function.trust_input = True
64+
return velocity_function
65+
66+
6067
def _theano_leapfrog_integrator(H, q, p, **theano_kwargs):
6168
"""Computes a theano function that computes one leapfrog step and the energy at the
6269
end of the trajectory.
@@ -115,6 +122,7 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
115122
"""
116123
H, q, dlogp = _theano_hamiltonian(model_vars, shared, logpt, potential)
117124
energy_function, p = _theano_energy_function(H, q, **theano_kwargs)
125+
velocity_function = _theano_velocity_function(H, p, **theano_kwargs)
118126
if use_single_leapfrog:
119127
try:
120128
_theano_integrator = INTEGRATORS_SINGLE[integrator]
@@ -125,7 +133,7 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
125133
if integrator != "leapfrog":
126134
raise ValueError("Only leapfrog is supported")
127135
integrator = _theano_leapfrog_integrator(H, q, p, **theano_kwargs)
128-
return H, energy_function, integrator, dlogp
136+
return H, energy_function, velocity_function, integrator, dlogp
129137

130138

131139
def energy(H, q, p):
@@ -214,11 +222,12 @@ def _theano_single_threestage(H, q, p, q_grad, **theano_kwargs):
214222
q_e = q_1be + floatX(b) * epsilon * H.pot.velocity(p_1ae)
215223
grad_e = H.dlogp(q_e)
216224
p_e = p_1ae + floatX(a) * epsilon * grad_e
225+
v_e = H.pot.velocity(p_e)
217226

218227
new_energy = energy(H, q_e, p_e)
219228

220229
f = theano.function(inputs=[q, p, q_grad, epsilon],
221-
outputs=[q_e, p_e, grad_e, new_energy],
230+
outputs=[q_e, p_e, v_e, grad_e, new_energy],
222231
**theano_kwargs)
223232
f.trust_input = True
224233
return f
@@ -250,10 +259,11 @@ def _theano_single_twostage(H, q, p, q_grad, **theano_kwargs):
250259
q_e = q_e2 + epsilon / 2 * H.pot.velocity(p_1ae)
251260
grad_e = H.dlogp(q_e)
252261
p_e = p_1ae + a * epsilon * grad_e
262+
v_e = H.pot.velocity(p_e)
253263

254264
new_energy = energy(H, q_e, p_e)
255265
f = theano.function(inputs=[q, p, q_grad, epsilon],
256-
outputs=[q_e, p_e, grad_e, new_energy],
266+
outputs=[q_e, p_e, v_e, grad_e, new_energy],
257267
**theano_kwargs)
258268
f.trust_input = True
259269
return f
@@ -273,15 +283,47 @@ def _theano_single_leapfrog(H, q, p, q_grad, **theano_kwargs):
273283
q_new_grad = H.dlogp(q_new)
274284
p_new += 0.5 * epsilon * q_new_grad # half momentum update
275285
energy_new = energy(H, q_new, p_new)
286+
v_new = H.pot.velocity(p_new)
276287

277288
f = theano.function(inputs=[q, p, q_grad, epsilon],
278-
outputs=[q_new, p_new, q_new_grad, energy_new], **theano_kwargs)
289+
outputs=[q_new, p_new, v_new, q_new_grad, energy_new],
290+
**theano_kwargs)
291+
f.trust_input = True
292+
return f
293+
294+
295+
def _theano_single_leapfrog3(H, q, p, q_grad, **theano_kwargs):
296+
"""Do three leapfrog steps."""
297+
step_size = tt.scalar('epsilon')
298+
step_size.tag.test_value = 1.
299+
300+
epsilon = step_size / 3
301+
302+
p_new = p + 0.5 * epsilon * q_grad # half momentum update
303+
q_new = q + epsilon * H.pot.velocity(p_new) # full position update
304+
305+
p_new = p_new + epsilon * H.dlogp(q_new)
306+
q_new = q_new + epsilon * H.pot.velocity(p_new)
307+
308+
p_new = p_new + epsilon * H.dlogp(q_new)
309+
q_new = q_new + epsilon * H.pot.velocity(p_new)
310+
311+
q_new_grad = H.dlogp(q_new)
312+
p_new = p_new + 0.5 * epsilon * q_new_grad
313+
314+
energy_new = energy(H, q_new, p_new)
315+
v_new = H.pot.velocity(p_new)
316+
317+
f = theano.function(inputs=[q, p, q_grad, step_size],
318+
outputs=[q_new, p_new, v_new, q_new_grad, energy_new],
319+
**theano_kwargs)
279320
f.trust_input = True
280321
return f
281322

282323

283324
INTEGRATORS_SINGLE = {
284325
'leapfrog': _theano_single_leapfrog,
285326
'two-stage': _theano_single_twostage,
286-
'three-stage': _theano_single_threestage
327+
'three-stage': _theano_single_threestage,
328+
'leapfrog3': _theano_single_leapfrog3,
287329
}

pymc3/tests/sampler_fixtures.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,6 @@ def make_model(cls):
5757
a = pm.Uniform("a", lower=-1, upper=1)
5858
return model
5959

60-
def test_interval(self):
61-
a = self.samples['a']
62-
npt.assert_almost_equal(((a > 0.1) & (a < 0.5)).mean(), 0.2, 2)
63-
6460

6561
class NormalFixture(KnownMean, KnownVariance, KnownCDF):
6662
means = {'a': 2 * np.ones(10)}
@@ -86,9 +82,8 @@ def make_model(cls):
8682
return model
8783

8884

89-
class StudentTFixture(KnownMean, KnownVariance, KnownCDF):
85+
class StudentTFixture(KnownMean, KnownCDF):
9086
means = {'a': 0}
91-
variances = {'a': 3}
9287
cdfs = {'a': stats.t(df=3).cdf}
9388
ks_thin = 10
9489

pymc3/tests/test_hmc.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def test_leapfrog_reversible_single():
3030
n = 3
3131
start, model, _ = models.non_normal(n)
3232

33-
integrators = ['leapfrog', 'two-stage', 'three-stage']
33+
integrators = ['leapfrog', 'two-stage', 'three-stage', 'leapfrog3']
3434
steps = [BaseHMC(vars=model.vars, model=model, integrator=method, use_single_leapfrog=True)
3535
for method in integrators]
3636
for method, step in zip(integrators, steps):
@@ -46,10 +46,10 @@ def test_leapfrog_reversible_single():
4646

4747
energy = step.compute_energy(q, p)
4848
for _ in range(n_steps):
49-
q, p, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
49+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
5050
p = -p
5151
for _ in range(n_steps):
52-
q, p, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
52+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
5353

5454
close_to(q, q0, 1e-8, str(('q', method, n_steps, epsilon)))
5555
close_to(-p, p0, 1e-8, str(('p', method, n_steps, epsilon)))

pymc3/tests/test_posteriors.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -62,13 +62,12 @@ class NUTSNormal(sf.NutsFixture, sf.NormalFixture):
6262

6363

6464
class NUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
65-
n_samples = 10000
65+
n_samples = 2000
66+
ks_thin = 5
6667
tune = 1000
6768
burn = 1000
6869
chains = 2
69-
min_n_eff = 2000
70-
rtol = 0.1
71-
atol = 0.05
70+
min_n_eff = 400
7271

7372

7473
@attr('extra')

0 commit comments

Comments
 (0)