Skip to content

Commit f73d2ff

Browse files
authored
Merge pull request #1769 from aseyboldt/multinomial-nuts
[WIP] Multinomial sampling for nuts
2 parents 8ad1af2 + de8ddf3 commit f73d2ff

File tree

8 files changed

+109
-83
lines changed

8 files changed

+109
-83
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
5959
if theano_kwargs is None:
6060
theano_kwargs = {}
6161

62-
self.H, self.compute_energy, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
62+
self.H, self.compute_energy, self.compute_velocity, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
6363
vars, shared, model.logpt, self.potential, use_single_leapfrog, integrator, **theano_kwargs)
6464

6565
super(BaseHMC, self).__init__(vars, shared, blocked=blocked)

pymc3/step_methods/hmc/nuts.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
__all__ = ['NUTS']
1212

1313

14-
def bern(p):
15-
return nr.uniform() < p
14+
def logbern(log_p):
    """Draw a Bernoulli sample given the log of the success probability.

    Parameters
    ----------
    log_p : float
        Natural log of the success probability. Values >= 0 always
        succeed; ``-inf`` never succeeds.

    Returns
    -------
    bool
        True with probability ``min(1, exp(log_p))``.

    Raises
    ------
    FloatingPointError
        If `log_p` is nan.
    """
    if np.isnan(log_p):
        raise FloatingPointError("log_p can't be nan.")
    # nr.uniform() samples from [0, 1), so it can return exactly 0.0 and
    # log(0.0) would emit a divide-by-zero warning (or raise under
    # np.seterr(divide='raise')).  The resulting -inf is the value we want
    # for the comparison, so only the floating point signal is silenced.
    with np.errstate(divide='ignore'):
        log_u = np.log(nr.uniform())
    return log_u < log_p
1618

1719

1820
class NUTS(BaseHMC):
@@ -77,7 +79,8 @@ class NUTS(BaseHMC):
7779
}]
7880

7981
def __init__(self, vars=None, Emax=1000, target_accept=0.8,
80-
gamma=0.05, k=0.75, t0=10, adapt_step_size=True, **kwargs):
82+
gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
83+
max_treedepth=10, **kwargs):
8184
"""
8285
Parameters
8386
----------
@@ -122,11 +125,13 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
122125
self.log_step_size_bar = 0
123126
self.m = 1
124127
self.adapt_step_size = adapt_step_size
128+
self.max_treedepth = max_treedepth
125129

126130
self.tune = True
127131

128132
def astep(self, q0):
129133
p0 = self.potential.random()
134+
v0 = self.compute_velocity(p0)
130135
start_energy = self.compute_energy(q0, p0)
131136

132137
if not self.adapt_step_size:
@@ -136,12 +141,11 @@ def astep(self, q0):
136141
else:
137142
step_size = np.exp(self.log_step_size_bar)
138143

139-
u = nr.uniform()
140-
start = Edge(q0, p0, self.dlogp(q0), start_energy)
141-
tree = Tree(self.leapfrog, start, u, step_size, self.Emax)
144+
start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
145+
tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
142146

143-
while True:
144-
direction = bern(0.5) * 2 - 1
147+
for _ in range(self.max_treedepth):
148+
direction = logbern(np.log(0.5)) * 2 - 1
145149
diverging, turning = tree.extend(direction)
146150
q = tree.proposal.q
147151

@@ -178,17 +182,17 @@ def competence(var):
178182

179183

180184
# A node in the NUTS tree that is at the far right or left of the tree
181-
Edge = namedtuple("Edge", 'q, p, q_grad, energy')
185+
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
182186

183187
# A proposal for the next position
184188
Proposal = namedtuple("Proposal", "q, energy, p_accept")
185189

186-
# A subtree of the binary tree build by nuts.
187-
Subtree = namedtuple("Subtree", "left, right, proposal, depth, size, accept_sum, n_proposals")
190+
# A subtree of the binary tree built by nuts.
191+
Subtree = namedtuple("Subtree", "left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
188192

189193

190194
class Tree(object):
191-
def __init__(self, leapfrog, start, u, step_size, Emax):
195+
def __init__(self, ndim, leapfrog, start, step_size, Emax):
192196
"""Binary tree from the NUTS algorithm.
193197
194198
Parameters
@@ -197,28 +201,26 @@ def __init__(self, leapfrog, start, u, step_size, Emax):
197201
A function that performs a single leapfrog step.
198202
start : Edge
199203
The starting point of the trajectory.
200-
u : float in [0, 1]
201-
Random slice sampling variable.
202204
step_size : float
203205
The step size to use in this tree
204206
Emax : float
205207
The maximum energy change to accept before aborting the
206208
transition as diverging.
207209
"""
210+
self.ndim = ndim
208211
self.leapfrog = leapfrog
209212
self.start = start
210-
self.log_u = np.log(u)
211213
self.step_size = step_size
212214
self.Emax = Emax
213215
self.start_energy = np.array(start.energy)
214216

215217
self.left = self.right = start
216218
self.proposal = Proposal(start.q, start.energy, 1.0)
217219
self.depth = 0
218-
self.size = 1
219-
# TODO Why not a global accept sum and n_proposals?
220-
#self.accept_sum = 0
221-
#self.n_proposals = 0
220+
self.log_size = 0
221+
self.accept_sum = 0
222+
self.n_proposals = 0
223+
self.p_sum = start.p.copy()
222224
self.max_energy_change = 0
223225

224226
def extend(self, direction):
@@ -239,40 +241,46 @@ def extend(self, direction):
239241
self.right = tree.right
240242
else:
241243
tree, diverging, turning = self._build_subtree(
242-
self.left, self.depth, floatX(np.asarray(- self.step_size)))
244+
self.left, self.depth, floatX(np.asarray(-self.step_size)))
243245
self.left = tree.right
244246

245-
ok = not (diverging or turning)
246-
if ok and bern(min(1, tree.size / self.size)):
247+
self.depth += 1
248+
self.accept_sum += tree.accept_sum
249+
self.n_proposals += tree.n_proposals
250+
251+
if diverging or turning:
252+
return diverging, turning
253+
254+
size1, size2 = self.log_size, tree.log_size
255+
if logbern(size2 - size1):
247256
self.proposal = tree.proposal
248257

249-
self.depth += 1
250-
self.size += tree.size
251-
# TODO why not +=
252-
#self.accept_sum += tree.accept_sum
253-
self.accept_sum = tree.accept_sum
254-
#self.n_proposals += tree.n_proposals
255-
self.n_proposals = tree.n_proposals
258+
self.log_size = np.logaddexp(self.log_size, tree.log_size)
259+
self.p_sum[:] += tree.p_sum
256260

257261
left, right = self.left, self.right
258-
span = right.q - left.q
259-
turning = turning or (span.dot(left.p) < 0) or (span.dot(right.p) < 0)
262+
p_sum = self.p_sum
263+
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
264+
260265
return diverging, turning
261266

262267
def _build_subtree(self, left, depth, epsilon):
263268
if depth == 0:
264269
right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
265270
right = Edge(*right)
266271
energy_change = right.energy - self.start_energy
272+
if np.isnan(energy_change):
273+
energy_change = np.inf
274+
267275
if np.abs(energy_change) > np.abs(self.max_energy_change):
268276
self.max_energy_change = energy_change
269277
p_accept = min(1, np.exp(-energy_change))
270278

271-
size = int(self.log_u + energy_change <= 0)
272-
diverging = not (self.log_u + energy_change < self.Emax)
279+
log_size = -energy_change
280+
diverging = energy_change > self.Emax
273281

274282
proposal = Proposal(right.q, right.energy, p_accept)
275-
tree = Subtree(right, right, proposal, 1, size, p_accept, 1)
283+
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
276284
return tree, diverging, False
277285

278286
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
@@ -281,20 +289,26 @@ def _build_subtree(self, left, depth, epsilon):
281289

282290
tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon)
283291

284-
size = tree1.size + tree2.size
285-
accept_sum = tree1.accept_sum + tree2.accept_sum
286-
n_proposals = tree1.n_proposals + tree2.n_proposals
287-
288292
left, right = tree1.left, tree2.right
289-
span = np.sign(epsilon) * (right.q - left.q)
290-
turning = turning or (span.dot(left.p) < 0) or (span.dot(right.p) < 0)
291293

292-
if bern(tree2.size * 1. / max(size, 1)):
293-
proposal = tree2.proposal
294+
if not (diverging or turning):
295+
p_sum = tree1.p_sum + tree2.p_sum
296+
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
297+
298+
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
299+
if logbern(tree2.log_size - log_size):
300+
proposal = tree2.proposal
301+
else:
302+
proposal = tree1.proposal
294303
else:
304+
p_sum = tree1.p_sum
305+
log_size = tree1.log_size
295306
proposal = tree1.proposal
296307

297-
tree = Subtree(left, right, proposal, depth, size, accept_sum, n_proposals)
308+
accept_sum = tree1.accept_sum + tree2.accept_sum
309+
n_proposals = tree1.n_proposals + tree2.n_proposals
310+
311+
tree = Subtree(left, right, p_sum, proposal, log_size, accept_sum, n_proposals)
298312
return tree, diverging, turning
299313

300314
def stats(self):

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from numpy import dot
22
from numpy.random import normal
33
import scipy.linalg
4+
import theano.tensor as tt
45
from theano.tensor import slinalg
56
from scipy.sparse import issparse
67

@@ -123,7 +124,7 @@ def __init__(self, A):
123124
self.L = scipy.linalg.cholesky(A, lower=True)
124125

125126
def velocity(self, x):
126-
return x.T.dot(self.A.T)
127+
return tt.dot(self.A, x)
127128

128129
def random(self):
129130
n = floatX(normal(size=self.L.shape[0]))

pymc3/step_methods/hmc/trajectory.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def _theano_energy_function(H, q, **theano_kwargs):
5757
return energy_function, p
5858

5959

60+
def _theano_velocity_function(H, p, **theano_kwargs):
    """Compile a theano function that maps a momentum to its velocity.

    Parameters
    ----------
    H : Hamiltonian
        Object whose ``pot.velocity`` builds the velocity expression.
    p : theano tensor
        Symbolic momentum used as the only input of the compiled function.
    theano_kwargs : passed to theano.function

    Returns
    -------
    The compiled function, with ``trust_input`` enabled.
    """
    velocity_expr = H.pot.velocity(p)
    compiled = theano.function(inputs=[p], outputs=velocity_expr,
                               **theano_kwargs)
    compiled.trust_input = True
    return compiled
65+
66+
6067
def _theano_leapfrog_integrator(H, q, p, **theano_kwargs):
6168
"""Computes a theano function that computes one leapfrog step and the energy at the
6269
end of the trajectory.
@@ -115,6 +122,7 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
115122
"""
116123
H, q, dlogp = _theano_hamiltonian(model_vars, shared, logpt, potential)
117124
energy_function, p = _theano_energy_function(H, q, **theano_kwargs)
125+
velocity_function = _theano_velocity_function(H, p, **theano_kwargs)
118126
if use_single_leapfrog:
119127
try:
120128
_theano_integrator = INTEGRATORS_SINGLE[integrator]
@@ -125,7 +133,7 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
125133
if integrator != "leapfrog":
126134
raise ValueError("Only leapfrog is supported")
127135
integrator = _theano_leapfrog_integrator(H, q, p, **theano_kwargs)
128-
return H, energy_function, integrator, dlogp
136+
return H, energy_function, velocity_function, integrator, dlogp
129137

130138

131139
def energy(H, q, p):
@@ -214,11 +222,12 @@ def _theano_single_threestage(H, q, p, q_grad, **theano_kwargs):
214222
q_e = q_1be + floatX(b) * epsilon * H.pot.velocity(p_1ae)
215223
grad_e = H.dlogp(q_e)
216224
p_e = p_1ae + floatX(a) * epsilon * grad_e
225+
v_e = H.pot.velocity(p_e)
217226

218227
new_energy = energy(H, q_e, p_e)
219228

220229
f = theano.function(inputs=[q, p, q_grad, epsilon],
221-
outputs=[q_e, p_e, grad_e, new_energy],
230+
outputs=[q_e, p_e, v_e, grad_e, new_energy],
222231
**theano_kwargs)
223232
f.trust_input = True
224233
return f
@@ -250,10 +259,11 @@ def _theano_single_twostage(H, q, p, q_grad, **theano_kwargs):
250259
q_e = q_e2 + epsilon / 2 * H.pot.velocity(p_1ae)
251260
grad_e = H.dlogp(q_e)
252261
p_e = p_1ae + a * epsilon * grad_e
262+
v_e = H.pot.velocity(p_e)
253263

254264
new_energy = energy(H, q_e, p_e)
255265
f = theano.function(inputs=[q, p, q_grad, epsilon],
256-
outputs=[q_e, p_e, grad_e, new_energy],
266+
outputs=[q_e, p_e, v_e, grad_e, new_energy],
257267
**theano_kwargs)
258268
f.trust_input = True
259269
return f
@@ -273,15 +283,17 @@ def _theano_single_leapfrog(H, q, p, q_grad, **theano_kwargs):
273283
q_new_grad = H.dlogp(q_new)
274284
p_new += 0.5 * epsilon * q_new_grad # half momentum update
275285
energy_new = energy(H, q_new, p_new)
286+
v_new = H.pot.velocity(p_new)
276287

277288
f = theano.function(inputs=[q, p, q_grad, epsilon],
278-
outputs=[q_new, p_new, q_new_grad, energy_new], **theano_kwargs)
289+
outputs=[q_new, p_new, v_new, q_new_grad, energy_new],
290+
**theano_kwargs)
279291
f.trust_input = True
280292
return f
281293

282294

283295
INTEGRATORS_SINGLE = {
284296
'leapfrog': _theano_single_leapfrog,
285297
'two-stage': _theano_single_twostage,
286-
'three-stage': _theano_single_threestage
298+
'three-stage': _theano_single_threestage,
287299
}

pymc3/tests/sampler_fixtures.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,6 @@ def make_model(cls):
5757
a = pm.Uniform("a", lower=-1, upper=1)
5858
return model
5959

60-
def test_interval(self):
61-
a = self.samples['a']
62-
npt.assert_almost_equal(((a > 0.1) & (a < 0.5)).mean(), 0.2, 2)
63-
6460

6561
class NormalFixture(KnownMean, KnownVariance, KnownCDF):
6662
means = {'a': 2 * np.ones(10)}
@@ -86,9 +82,8 @@ def make_model(cls):
8682
return model
8783

8884

89-
class StudentTFixture(KnownMean, KnownVariance, KnownCDF):
85+
class StudentTFixture(KnownMean, KnownCDF):
9086
means = {'a': 0}
91-
variances = {'a': 3}
9287
cdfs = {'a': stats.t(df=3).cdf}
9388
ks_thin = 10
9489

pymc3/tests/test_hmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def test_leapfrog_reversible_single():
4646

4747
energy = step.compute_energy(q, p)
4848
for _ in range(n_steps):
49-
q, p, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
49+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
5050
p = -p
5151
for _ in range(n_steps):
52-
q, p, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
52+
q, p, v, dlogp, _ = step.leapfrog(q, p, dlogp, np.array(epsilon))
5353

5454
close_to(q, q0, 1e-8, str(('q', method, n_steps, epsilon)))
5555
close_to(-p, p0, 1e-8, str(('p', method, n_steps, epsilon)))

pymc3/tests/test_posteriors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,12 @@ class NUTSNormal(sf.NutsFixture, sf.NormalFixture):
6262

6363

6464
class NUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
65-
n_samples = 10000
65+
n_samples = 2000
66+
ks_thin = 5
6667
tune = 1000
6768
burn = 1000
6869
chains = 2
69-
min_n_eff = 2000
70-
rtol = 0.1
71-
atol = 0.05
70+
min_n_eff = 400
7271

7372

7473
@attr('extra')

0 commit comments

Comments
 (0)