@@ -311,7 +311,6 @@ def sample(
311
311
n_init : int = 200_000 ,
312
312
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
313
313
trace : Optional [Union [BaseTrace , List [str ]]] = None ,
314
- chain_idx : int = 0 ,
315
314
chains : Optional [int ] = None ,
316
315
cores : Optional [int ] = None ,
317
316
tune : int = 1000 ,
@@ -353,9 +352,6 @@ def sample(
353
352
trace : backend or list
354
353
This should be a backend instance, or a list of variables to track.
355
354
If None or a list of variables, the NDArray backend is used.
356
- chain_idx : int
357
- Chain number used to store sample in backend. If ``chains`` is greater than one, chain
358
- numbers will start here.
359
355
chains : int
360
356
The number of chains to sample. Running independent chains is important for some
361
357
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -569,7 +565,6 @@ def sample(
569
565
"step" : step ,
570
566
"start" : initial_points ,
571
567
"trace" : trace ,
572
- "chain" : chain_idx ,
573
568
"chains" : chains ,
574
569
"tune" : tune ,
575
570
"progressbar" : progressbar ,
@@ -658,7 +653,7 @@ def sample(
658
653
# count the number of tune/draw iterations that happened
659
654
# ideally via the "tune" statistic, but not all samplers record it!
660
655
if "tune" in mtrace .stat_names :
661
- stat = mtrace .get_sampler_stats ("tune" , chains = chain_idx )
656
+ stat = mtrace .get_sampler_stats ("tune" , chains = 0 )
662
657
# when CompoundStep is used, the stat is 2 dimensional!
663
658
if len (stat .shape ) == 2 :
664
659
stat = stat [:, 0 ]
@@ -734,7 +729,6 @@ def _check_start_shape(model, start: PointType):
734
729
735
730
def _sample_many (
736
731
draws : int ,
737
- chain : int ,
738
732
chains : int ,
739
733
start : Sequence [PointType ],
740
734
random_seed : Optional [Sequence [RandomSeed ]],
@@ -748,8 +742,6 @@ def _sample_many(
748
742
----------
749
743
draws: int
750
744
The number of samples to draw
751
- chain: int
752
- Number of the first chain in the sequence.
753
745
chains: int
754
746
Total number of chains to sample.
755
747
start: list
@@ -768,7 +760,7 @@ def _sample_many(
768
760
for i in range (chains ):
769
761
trace = _sample (
770
762
draws = draws ,
771
- chain = chain + i ,
763
+ chain = i ,
772
764
start = start [i ],
773
765
step = step ,
774
766
random_seed = None if random_seed is None else random_seed [i ],
@@ -791,7 +783,6 @@ def _sample_many(
791
783
792
784
def _sample_population (
793
785
draws : int ,
794
- chain : int ,
795
786
chains : int ,
796
787
start : Sequence [PointType ],
797
788
random_seed : RandomSeed ,
@@ -808,8 +799,6 @@ def _sample_population(
808
799
----------
809
800
draws : int
810
801
The number of samples to draw
811
- chain : int
812
- The number of the first chain in the population
813
802
chains : int
814
803
The total number of chains in the population
815
804
start : list
@@ -832,7 +821,6 @@ def _sample_population(
832
821
"""
833
822
sampling = _prepare_iter_population (
834
823
draws ,
835
- [chain + c for c in range (chains )],
836
824
step ,
837
825
start ,
838
826
parallelize ,
@@ -952,8 +940,7 @@ def iter_sample(
952
940
This should be a backend instance, or a list of variables to track.
953
941
If None or a list of variables, the NDArray backend is used.
954
942
chain : int, optional
955
- Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
956
- will start here.
943
+ Chain number used to store sample in backend.
957
944
tune : int, optional
958
945
Number of iterations to tune (defaults to 0).
959
946
model : Model (optional if in ``with`` context)
@@ -1008,8 +995,7 @@ def _iter_sample(
1008
995
This should be a backend instance, or a list of variables to track.
1009
996
If None or a list of variables, the NDArray backend is used.
1010
997
chain : int, optional
1011
- Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
1012
- will start here.
998
+ Chain number used to store sample in backend.
1013
999
tune : int, optional
1014
1000
Number of iterations to tune (defaults to 0).
1015
1001
model : Model (optional if in ``with`` context)
@@ -1247,7 +1233,6 @@ def step(self, tune_stop: bool, population):
1247
1233
1248
1234
def _prepare_iter_population (
1249
1235
draws : int ,
1250
- chains : list ,
1251
1236
step ,
1252
1237
start : Sequence [PointType ],
1253
1238
parallelize : bool ,
@@ -1262,8 +1247,6 @@ def _prepare_iter_population(
1262
1247
----------
1263
1248
draws : int
1264
1249
The number of samples to draw
1265
- chains : list
1266
- The chain numbers in the population
1267
1250
step : function
1268
1251
Step function (should be or contain a population step method)
1269
1252
start : list
@@ -1282,8 +1265,7 @@ def _prepare_iter_population(
1282
1265
_iter_population : generator
1283
1266
Yields traces of all chains at the same time
1284
1267
"""
1285
- # chains contains the chain numbers, but for indexing we need indices...
1286
- nchains = len (chains )
1268
+ nchains = len (start )
1287
1269
model = modelcontext (model )
1288
1270
draws = int (draws )
1289
1271
@@ -1327,7 +1309,7 @@ def _prepare_iter_population(
1327
1309
trace = None ,
1328
1310
model = model ,
1329
1311
)
1330
- for c in chains
1312
+ for c in range ( nchains )
1331
1313
]
1332
1314
1333
1315
# 4. configure the PopulationStepper (expensive call)
@@ -1457,7 +1439,6 @@ def _mp_sample(
1457
1439
step ,
1458
1440
chains : int ,
1459
1441
cores : int ,
1460
- chain : int ,
1461
1442
random_seed : Sequence [RandomSeed ],
1462
1443
start : Sequence [PointType ],
1463
1444
progressbar : bool = True ,
@@ -1482,8 +1463,6 @@ def _mp_sample(
1482
1463
The number of chains to sample.
1483
1464
cores : int
1484
1465
The number of chains to run in parallel.
1485
- chain : int
1486
- Number of the first chain.
1487
1466
random_seed : list of random seeds
1488
1467
Random seeds for each chain.
1489
1468
start : list
@@ -1520,26 +1499,25 @@ def _mp_sample(
1520
1499
trace = trace ,
1521
1500
model = model ,
1522
1501
)
1523
- for chain_number in range (chain , chain + chains )
1502
+ for chain_number in range (chains )
1524
1503
]
1525
1504
1526
1505
sampler = ps .ParallelSampler (
1527
- draws ,
1528
- tune ,
1529
- chains ,
1530
- cores ,
1531
- random_seed ,
1532
- start ,
1533
- step ,
1534
- chain ,
1535
- progressbar ,
1506
+ draws = draws ,
1507
+ tune = tune ,
1508
+ chains = chains ,
1509
+ cores = cores ,
1510
+ seeds = random_seed ,
1511
+ start_points = start ,
1512
+ step_method = step ,
1513
+ progressbar = progressbar ,
1536
1514
mp_ctx = mp_ctx ,
1537
1515
)
1538
1516
try :
1539
1517
try :
1540
1518
with sampler :
1541
1519
for draw in sampler :
1542
- strace = traces [draw .chain - chain ]
1520
+ strace = traces [draw .chain ]
1543
1521
if strace .supports_sampler_stats and draw .stats is not None :
1544
1522
strace .record (draw .point , draw .stats )
1545
1523
else :
@@ -1553,7 +1531,7 @@ def _mp_sample(
1553
1531
callback (trace = trace , draw = draw )
1554
1532
1555
1533
except ps .ParallelSamplingError as error :
1556
- strace = traces [error ._chain - chain ]
1534
+ strace = traces [error ._chain ]
1557
1535
strace ._add_warnings (error ._warnings )
1558
1536
for strace in traces :
1559
1537
strace .close ()
@@ -1998,18 +1976,12 @@ def sample_posterior_predictive(
1998
1976
_log .info (f"Sampling: { list (sorted (volatile_basic_rvs , key = lambda var : var .name ))} " ) # type: ignore
1999
1977
ppc_trace_t = _DefaultTrace (samples )
2000
1978
try :
2001
- if isinstance (_trace , MultiTrace ):
2002
- # trace dict is unordered, but we want to return ppc samples in
2003
- # a predictable ordering, so sort the chain indices
2004
- chain_idx_mapping = sorted (_trace ._straces .keys ())
2005
1979
for idx in indices :
2006
1980
if nchain > 1 :
2007
1981
# the trace object will either be a MultiTrace (and have _straces)...
2008
1982
if hasattr (_trace , "_straces" ):
2009
1983
chain_idx , point_idx = np .divmod (idx , len_trace )
2010
1984
chain_idx = chain_idx % nchain
2011
- # chain indices might not always start at 0, convert to proper index
2012
- chain_idx = chain_idx_mapping [chain_idx ]
2013
1985
param = cast (MultiTrace , _trace )._straces [chain_idx ].point (point_idx )
2014
1986
# ... or a PointList
2015
1987
else :
0 commit comments