Skip to content

Commit 2e80cef

Browse files
committed
Add base and NDArray backend
This commit contains a new backend for sampling and selecting values. Non-backend files have been changed to work with the new backend. Everything seems to be working with the exception of two issues (marked with FIXME): 1. pymc.plots.forestplot has not been updated yet for the new backend. 2. The previous behavior of passing a trace object to sample is not the same. I updated stochastic_volatility to do this with the same trace object. This commit also introduces a change to `sample`/`psample`. Instead of having separate functions, `sample` now takes a keyword argument `threads`, and if this is greater than one, the multiprocessing version is used. The method for selecting values has also been changed. Traces can still be indexed to return values, a new slice, or a point (depending on the index), but the handling of chains is different. The trace object now manages multiple chains itself instead of having a separate class to manage the single trace object. `get_values` is the main method for selecting values. By default, it returns separate results for all the chains. The chains can be combined with the `combine` flag, and particular chains can be selected with the `chains` argument. The motivation for both the sample and selection changes above was to have a unified interface for dealing with multiple chains, as most people are likely going to take advantage of the parallel sampling.
1 parent 683507b commit 2e80cef

20 files changed

+1153
-489
lines changed

pymc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .distributions import *
55
from .math import *
66

7-
from .trace import *
7+
88
from .sampling import *
99
from .stats import summary
1010
from .step_methods import *

pymc/backends/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pymc.backends.ndarray import NDArray

pymc/backends/base.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""Base backend for traces
2+
3+
These are the base classes for all trace backends. They define all the
4+
required methods for sampling and value selection that should be
5+
overridden or implemented in children classes. See the docstring for
6+
pymc.backends for more information (including creating custom backends).
7+
"""
8+
import numpy as np
9+
from pymc.model import modelcontext
10+
11+
12+
class Backend(object):
    """Base class for the sampling side of a trace backend.

    A backend owns a trace object and fills it chain by chain:
    `setup_samples` allocates per-variable storage for a chain and
    `record` writes one draw into it. Storage details are delegated to
    `_initialize_trace`, `_create_trace`, `_store_value`, `commit` and
    `close`, all of which raise NotImplementedError here and must be
    supplied by children classes.

    Parameters
    ----------
    name : str
        Stored as `self.name`.
    model : Model
        Passed through `modelcontext`; the result is stored as
        `self.model`.
    variables : list
        Variables to record. If None, `model.unobserved_RVs` is used.
    """

    def __init__(self, name, model=None, variables=None):
        self.name = name

        ## model attributes (populated by _setup_model)
        self.variables = None
        self.var_names = None
        self.var_shapes = None
        self._fn = None

        self.model = modelcontext(model)
        if self.model:
            self._setup_model(self.model, variables)

        ## set by setup_samples
        self.chain = None
        self.trace = None

        ## requested number of draws, keyed by chain number
        self._draws = {}

    def _setup_model(self, model, variables):
        """Cache variable names, the evaluation function and shapes.

        Shapes are obtained by evaluating the variables once at
        `model.test_point`.
        """
        selected = model.unobserved_RVs if variables is None else variables

        self.variables = selected
        self.var_names = [str(var) for var in selected]
        self._fn = model.fastfn(selected)

        test_values = self._fn(model.test_point)
        self.var_shapes = {
            name: value.shape
            for name, value in zip(self.var_names, test_values)
        }

    def setup_samples(self, draws, chain):
        """Prepare structure to store traces

        Parameters
        ----------
        draws : int
            Number of sampling iterations
        chain : int
            Chain number to store trace under
        """
        self.chain = chain
        self._draws[chain] = draws

        # The trace object is created lazily on the first chain and then
        # reused for every later chain.
        if self.trace is None:
            self.trace = self._initialize_trace()

        trace = self.trace
        trace._draws[chain] = draws
        trace.backend = self

        # Start the chain from an empty mapping, then add one storage
        # object per variable; the leading dimension is the draw index.
        trace.samples[chain] = {}
        chain_samples = trace.samples[chain]
        for var_name, var_shape in self.var_shapes.items():
            trace_shape = [draws] + list(var_shape)
            chain_samples[var_name] = self._create_trace(
                chain, var_name, trace_shape)

    def record(self, point, draw):
        """Record the value of the current iteration

        Parameters
        ----------
        point : dict
            Map of point values to variable names
        draw : int
            Current sampling iteration
        """
        values = self._fn(point)
        for var_name, value in zip(self.var_names, values):
            var_trace = self.trace.samples[self.chain][var_name]
            self._store_value(draw, var_trace, value)

    def clean_interrupt(self, current_draw):
        """Clean up sampling after interruption

        Perform any clean up not taken care of by `close`. After
        KeyboardInterrupt, `sample` calls `close`, so `close` should not
        be called here.

        The base implementation overwrites the trace's draw count for
        the current chain with `current_draw`.
        """
        self.trace._draws[self.chain] = current_draw

    ## Sampling methods that children must define

    def _initialize_trace(self):
        """Return a new trace object for this backend."""
        raise NotImplementedError

    def _create_trace(self, chain, var_name, shape):
        """Create trace for a variable

        Parameters
        ----------
        chain : int
            Current chain number
        var_name : str
            Name of variable
        shape : tuple
            Shape of the trace. The first element corresponds to the
            number of draws.
        """
        raise NotImplementedError

    def _store_value(self, draw, var_trace, value):
        """Store `value` for iteration `draw` in `var_trace`."""
        raise NotImplementedError

    def commit(self):
        """Commit samples to backend

        This is called at set intervals during sampling.
        """
        raise NotImplementedError

    def close(self):
        """Close the database backend

        This is called after sampling has finished.
        """
        raise NotImplementedError
131+
132+
133+
class Trace(object):
    """Base class for the selection side of a trace backend.

    Holds the sampled values for one or more chains and exposes them
    through indexing, `get_values`, `point` and slicing. The selection
    methods raise NotImplementedError here and must be supplied by
    children classes.

    Parameters
    ----------
    var_names : list of strs
        Sample variables names
    backend : Backend object

    Attributes
    ----------
    backend : Backend object
    var_names
    samples : dict of dicts
        Sample values keyed by chain and variable name
    nchains : int
        Number of sampling chains
    chains : list of ints
        List of sampling chain numbers
    default_chain : int
        Chain to be used if single chain requested
    active_chains : list of ints
        Chains whose values are used in operations
    """
    def __init__(self, var_names, backend=None):
        self.var_names = var_names
        self.backend = backend

        self.samples = {}
        self._draws = {}  # number of draws keyed by chain number

        # Empty list / None mean "not set"; see the matching properties.
        self._active_chains = []
        self._default_chain = None

    @property
    def nchains(self):
        """Number of chains

        A chain is created for each sample call (including parallel
        threads).
        """
        return len(self.samples)

    @property
    def chains(self):
        """All chains in trace"""
        return list(self.samples.keys())

    @property
    def default_chain(self):
        """Default chain to use for operations that require one chain (e.g.,
        `point`)

        Falls back to the last entry of `active_chains` when no default
        has been set explicitly.
        """
        if self._default_chain is not None:
            return self._default_chain
        return self.active_chains[-1]

    @default_chain.setter
    def default_chain(self, value):
        self._default_chain = value

    @property
    def active_chains(self):
        """List of chains to be used. Defaults to all.
        """
        return self._active_chains or self.chains

    @active_chains.setter
    def active_chains(self, values):
        # Accept either an iterable of chain numbers or a single chain
        # number; a non-iterable raises TypeError and is wrapped.
        try:
            self._active_chains = list(values)
        except TypeError:
            self._active_chains = [values]

    def __len__(self):
        """Number of draws recorded for `default_chain`."""
        return self._draws[self.default_chain]

    def __getitem__(self, idx):
        """Return a sliced trace, a point, or values, depending on `idx`.

        A slice is handed to `_slice`. Anything else is first tried as
        an index for `point`; if that raises ValueError or TypeError,
        `idx` is treated as a variable name for `get_values`.
        """
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(idx)
        except (ValueError, TypeError):
            pass
        return self.get_values(idx)

    ## Selection methods that children must define

    def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None,
                   squeeze=True):
        """Get values from samples

        Parameters
        ----------
        var_name : str
        burn : int
        thin : int
        combine : bool
            If True, results from all chains will be concatenated.
        chains : list
            Chains to retrieve. If None, `active_chains` is used.
        squeeze : bool
            If True, return a single array rather than a list when the
            resulting list of values has only one element. This also
            applies to the single concatenated array produced when
            `combine` is True.

        Returns
        -------
        A list of NumPy arrays of values, or a single array (see
        `squeeze`)
        """
        raise NotImplementedError

    def _slice(self, idx):
        """Slice trace object"""
        raise NotImplementedError

    def point(self, idx, chain=None):
        """Return dictionary of point values at `idx` for current chain
        with variables names as keys.

        If `chain` is not specified, `default_chain` is used.
        """
        raise NotImplementedError
261+
262+
263+
def merge_chains(traces):
    """Merge chains from trace instances

    The first trace in `traces` is used as the base; the first chain of
    every other trace is added to it in place. Both the samples and,
    when the traces record them, the per-chain draw counts are carried
    over, so that length lookups on the merged trace work for the newly
    added chains as well.

    Parameters
    ----------
    traces : list
        Backend trace instances. Each instance should have only one
        chain, and all chain numbers should be unique.

    Raises
    ------
    ValueError is raised if any traces have the same current chain
    number. All chain numbers are validated before anything is merged,
    so the base trace is left untouched when this happens.

    Returns
    -------
    Backend instance with merged chains (the first element of `traces`)
    """
    base_trace = traces[0]

    # First pass: validate chain numbers without modifying anything, so
    # a conflict cannot leave the base trace half-merged.
    taken = set(base_trace.samples)
    pending = []
    for new_trace in traces[1:]:
        new_chain = new_trace.chains[0]
        if new_chain in taken:
            raise ValueError('Trace chain numbers conflict.')
        taken.add(new_chain)
        pending.append((new_chain, new_trace))

    # Second pass: move samples and draw counts into the base trace.
    base_draws = getattr(base_trace, '_draws', None)
    for new_chain, new_trace in pending:
        base_trace.samples[new_chain] = new_trace.samples[new_chain]
        new_draws = getattr(new_trace, '_draws', None)
        if (base_draws is not None and new_draws is not None
                and new_chain in new_draws):
            base_draws[new_chain] = new_draws[new_chain]
    return base_trace
288+
289+
290+
def _squeeze_cat(results, combine, squeeze):
    """Squeeze and concatenate the results depending on values of
    `combine` and `squeeze`

    With `combine`, `results` is concatenated into one array, which is
    returned bare if `squeeze` is true and inside a one-item list
    otherwise. Without `combine`, `results` is returned as is, except
    that a single-element `results` is unwrapped when `squeeze` is true.
    """
    if not combine:
        if squeeze and len(results) == 1:
            return results[0]
        return results

    merged = np.concatenate(results)
    return merged if squeeze else [merged]

0 commit comments

Comments
 (0)