Skip to content

Commit 91dbfd2

Browse files
michaelosthegericardoV94
authored andcommitted
Drop support for custom chain numbering start
1 parent 9abf4e0 commit 91dbfd2

File tree

4 files changed

+29
-64
lines changed

4 files changed

+29
-64
lines changed

pymc/parallel_sampling.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,14 @@ def terminate_all(processes, patience=2):
389389
class ParallelSampler:
390390
def __init__(
391391
self,
392+
*,
392393
draws: int,
393394
tune: int,
394395
chains: int,
395396
cores: int,
396397
seeds: Sequence["RandomSeed"],
397398
start_points: Sequence[Dict[str, np.ndarray]],
398399
step_method,
399-
start_chain_num: int = 0,
400400
progressbar: bool = True,
401401
mp_ctx=None,
402402
):
@@ -420,7 +420,7 @@ def __init__(
420420
tune,
421421
step_method,
422422
step_method_pickled,
423-
chain + start_chain_num,
423+
chain,
424424
seed,
425425
start,
426426
mp_ctx,
@@ -434,7 +434,6 @@ def __init__(
434434
self._max_active = cores
435435

436436
self._in_context = False
437-
self._start_chain_num = start_chain_num
438437

439438
self._progress = None
440439
self._divergences = 0

pymc/sampling.py

+17-45
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def sample(
311311
n_init: int = 200_000,
312312
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
313313
trace: Optional[Union[BaseTrace, List[str]]] = None,
314-
chain_idx: int = 0,
315314
chains: Optional[int] = None,
316315
cores: Optional[int] = None,
317316
tune: int = 1000,
@@ -353,9 +352,6 @@ def sample(
353352
trace : backend or list
354353
This should be a backend instance, or a list of variables to track.
355354
If None or a list of variables, the NDArray backend is used.
356-
chain_idx : int
357-
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
358-
numbers will start here.
359355
chains : int
360356
The number of chains to sample. Running independent chains is important for some
361357
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -569,7 +565,6 @@ def sample(
569565
"step": step,
570566
"start": initial_points,
571567
"trace": trace,
572-
"chain": chain_idx,
573568
"chains": chains,
574569
"tune": tune,
575570
"progressbar": progressbar,
@@ -658,7 +653,7 @@ def sample(
658653
# count the number of tune/draw iterations that happened
659654
# ideally via the "tune" statistic, but not all samplers record it!
660655
if "tune" in mtrace.stat_names:
661-
stat = mtrace.get_sampler_stats("tune", chains=chain_idx)
656+
stat = mtrace.get_sampler_stats("tune", chains=0)
662657
# when CompoundStep is used, the stat is 2 dimensional!
663658
if len(stat.shape) == 2:
664659
stat = stat[:, 0]
@@ -734,7 +729,6 @@ def _check_start_shape(model, start: PointType):
734729

735730
def _sample_many(
736731
draws: int,
737-
chain: int,
738732
chains: int,
739733
start: Sequence[PointType],
740734
random_seed: Optional[Sequence[RandomSeed]],
@@ -748,8 +742,6 @@ def _sample_many(
748742
----------
749743
draws: int
750744
The number of samples to draw
751-
chain: int
752-
Number of the first chain in the sequence.
753745
chains: int
754746
Total number of chains to sample.
755747
start: list
@@ -768,7 +760,7 @@ def _sample_many(
768760
for i in range(chains):
769761
trace = _sample(
770762
draws=draws,
771-
chain=chain + i,
763+
chain=i,
772764
start=start[i],
773765
step=step,
774766
random_seed=None if random_seed is None else random_seed[i],
@@ -791,7 +783,6 @@ def _sample_many(
791783

792784
def _sample_population(
793785
draws: int,
794-
chain: int,
795786
chains: int,
796787
start: Sequence[PointType],
797788
random_seed: RandomSeed,
@@ -808,8 +799,6 @@ def _sample_population(
808799
----------
809800
draws : int
810801
The number of samples to draw
811-
chain : int
812-
The number of the first chain in the population
813802
chains : int
814803
The total number of chains in the population
815804
start : list
@@ -832,7 +821,6 @@ def _sample_population(
832821
"""
833822
sampling = _prepare_iter_population(
834823
draws,
835-
[chain + c for c in range(chains)],
836824
step,
837825
start,
838826
parallelize,
@@ -952,8 +940,7 @@ def iter_sample(
952940
This should be a backend instance, or a list of variables to track.
953941
If None or a list of variables, the NDArray backend is used.
954942
chain : int, optional
955-
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
956-
will start here.
943+
Chain number used to store sample in backend.
957944
tune : int, optional
958945
Number of iterations to tune (defaults to 0).
959946
model : Model (optional if in ``with`` context)
@@ -1008,8 +995,7 @@ def _iter_sample(
1008995
This should be a backend instance, or a list of variables to track.
1009996
If None or a list of variables, the NDArray backend is used.
1010997
chain : int, optional
1011-
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
1012-
will start here.
998+
Chain number used to store sample in backend.
1013999
tune : int, optional
10141000
Number of iterations to tune (defaults to 0).
10151001
model : Model (optional if in ``with`` context)
@@ -1247,7 +1233,6 @@ def step(self, tune_stop: bool, population):
12471233

12481234
def _prepare_iter_population(
12491235
draws: int,
1250-
chains: list,
12511236
step,
12521237
start: Sequence[PointType],
12531238
parallelize: bool,
@@ -1262,8 +1247,6 @@ def _prepare_iter_population(
12621247
----------
12631248
draws : int
12641249
The number of samples to draw
1265-
chains : list
1266-
The chain numbers in the population
12671250
step : function
12681251
Step function (should be or contain a population step method)
12691252
start : list
@@ -1282,8 +1265,7 @@ def _prepare_iter_population(
12821265
_iter_population : generator
12831266
Yields traces of all chains at the same time
12841267
"""
1285-
# chains contains the chain numbers, but for indexing we need indices...
1286-
nchains = len(chains)
1268+
nchains = len(start)
12871269
model = modelcontext(model)
12881270
draws = int(draws)
12891271

@@ -1327,7 +1309,7 @@ def _prepare_iter_population(
13271309
trace=None,
13281310
model=model,
13291311
)
1330-
for c in chains
1312+
for c in range(nchains)
13311313
]
13321314

13331315
# 4. configure the PopulationStepper (expensive call)
@@ -1457,7 +1439,6 @@ def _mp_sample(
14571439
step,
14581440
chains: int,
14591441
cores: int,
1460-
chain: int,
14611442
random_seed: Sequence[RandomSeed],
14621443
start: Sequence[PointType],
14631444
progressbar: bool = True,
@@ -1482,8 +1463,6 @@ def _mp_sample(
14821463
The number of chains to sample.
14831464
cores : int
14841465
The number of chains to run in parallel.
1485-
chain : int
1486-
Number of the first chain.
14871466
random_seed : list of random seeds
14881467
Random seeds for each chain.
14891468
start : list
@@ -1520,26 +1499,25 @@ def _mp_sample(
15201499
trace=trace,
15211500
model=model,
15221501
)
1523-
for chain_number in range(chain, chain + chains)
1502+
for chain_number in range(chains)
15241503
]
15251504

15261505
sampler = ps.ParallelSampler(
1527-
draws,
1528-
tune,
1529-
chains,
1530-
cores,
1531-
random_seed,
1532-
start,
1533-
step,
1534-
chain,
1535-
progressbar,
1506+
draws=draws,
1507+
tune=tune,
1508+
chains=chains,
1509+
cores=cores,
1510+
seeds=random_seed,
1511+
start_points=start,
1512+
step_method=step,
1513+
progressbar=progressbar,
15361514
mp_ctx=mp_ctx,
15371515
)
15381516
try:
15391517
try:
15401518
with sampler:
15411519
for draw in sampler:
1542-
strace = traces[draw.chain - chain]
1520+
strace = traces[draw.chain]
15431521
if strace.supports_sampler_stats and draw.stats is not None:
15441522
strace.record(draw.point, draw.stats)
15451523
else:
@@ -1553,7 +1531,7 @@ def _mp_sample(
15531531
callback(trace=trace, draw=draw)
15541532

15551533
except ps.ParallelSamplingError as error:
1556-
strace = traces[error._chain - chain]
1534+
strace = traces[error._chain]
15571535
strace._add_warnings(error._warnings)
15581536
for strace in traces:
15591537
strace.close()
@@ -1998,18 +1976,12 @@ def sample_posterior_predictive(
19981976
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
19991977
ppc_trace_t = _DefaultTrace(samples)
20001978
try:
2001-
if isinstance(_trace, MultiTrace):
2002-
# trace dict is unordered, but we want to return ppc samples in
2003-
# a predictable ordering, so sort the chain indices
2004-
chain_idx_mapping = sorted(_trace._straces.keys())
20051979
for idx in indices:
20061980
if nchain > 1:
20071981
# the trace object will either be a MultiTrace (and have _straces)...
20081982
if hasattr(_trace, "_straces"):
20091983
chain_idx, point_idx = np.divmod(idx, len_trace)
20101984
chain_idx = chain_idx % nchain
2011-
# chain indices might not always start at 0, convert to proper index
2012-
chain_idx = chain_idx_mapping[chain_idx]
20131985
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
20141986
# ... or a PointList
20151987
else:

pymc/tests/test_parallel_sampling.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,16 @@ def test_iterator():
183183
step = pm.CompoundStep([step1, step2])
184184

185185
start = {"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}
186-
sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False)
186+
sampler = ps.ParallelSampler(
187+
draws=10,
188+
tune=10,
189+
chains=3,
190+
cores=2,
191+
seeds=[2, 3, 4],
192+
start_points=[start] * 3,
193+
step_method=step,
194+
progressbar=False,
195+
)
187196
with sampler:
188197
for draw in sampler:
189198
pass

pymc/tests/test_sampling.py

-15
Original file line numberDiff line numberDiff line change
@@ -505,21 +505,6 @@ def test_partial_trace_sample():
505505
assert "b" not in idata.posterior
506506

507507

508-
def test_chain_idx():
509-
# see https://github.com/pymc-devs/pymc/issues/4469
510-
with pm.Model():
511-
mu = pm.Normal("mu")
512-
x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
513-
# note draws-tune must be >100 AND we need an observed RV for this to properly
514-
# trigger convergence checks, which is one particular case in which this failed
515-
# before
516-
idata = pm.sample(draws=150, tune=10, chain_idx=1)
517-
518-
ppc = pm.sample_posterior_predictive(idata)
519-
# TODO FIXME: Assert something.
520-
ppc = pm.sample_posterior_predictive(idata, keep_size=True)
521-
522-
523508
@pytest.mark.parametrize(
524509
"n_points, tune, expected_length, expected_n_traces",
525510
[

0 commit comments

Comments
 (0)