40
40
41
41
42
42
class ParallelSamplingError (Exception ):
43
- def __init__ (self , message , chain , warnings = None ):
43
+ def __init__ (self , message , chain ):
44
44
super ().__init__ (message )
45
- if warnings is None :
46
- warnings = []
47
45
self ._chain = chain
48
- self ._warnings = warnings
49
46
50
47
51
48
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -74,8 +71,8 @@ def rebuild_exc(exc, tb):
74
71
75
72
76
73
# Messages
77
- # ('writing_done', is_last, sample_idx, tuning, stats, warns )
78
- # ('error', warnings, *exception_info)
74
+ # ('writing_done', is_last, sample_idx, tuning, stats)
75
+ # ('error', *exception_info)
79
76
80
77
# ('abort', reason)
81
78
# ('write_next',)
@@ -133,7 +130,7 @@ def run(self):
133
130
e = ExceptionWithTraceback (e , e .__traceback__ )
134
131
# Send is not blocking so we have to force a wait for the abort
135
132
# message
136
- self ._msg_pipe .send (("error" , None , e ))
133
+ self ._msg_pipe .send (("error" , e ))
137
134
self ._wait_for_abortion ()
138
135
finally :
139
136
self ._msg_pipe .close ()
@@ -181,9 +178,8 @@ def _start_loop(self):
181
178
try :
182
179
point , stats = self ._compute_point ()
183
180
except SamplingError as e :
184
- warns = self ._collect_warnings ()
185
181
e = ExceptionWithTraceback (e , e .__traceback__ )
186
- self ._msg_pipe .send (("error" , warns , e ))
182
+ self ._msg_pipe .send (("error" , e ))
187
183
else :
188
184
return
189
185
@@ -193,11 +189,7 @@ def _start_loop(self):
193
189
elif msg [0 ] == "write_next" :
194
190
self ._write_point (point )
195
191
is_last = draw + 1 == self ._draws + self ._tune
196
- if is_last :
197
- warns = self ._collect_warnings ()
198
- else :
199
- warns = None
200
- self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats , warns ))
192
+ self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats ))
201
193
draw += 1
202
194
else :
203
195
raise ValueError ("Unknown message " + msg [0 ])
@@ -210,12 +202,6 @@ def _compute_point(self):
210
202
stats = None
211
203
return point , stats
212
204
213
- def _collect_warnings (self ):
214
- if hasattr (self ._step_method , "warnings" ):
215
- return self ._step_method .warnings ()
216
- else :
217
- return []
218
-
219
205
220
206
def _run_process (* args ):
221
207
_Process (* args ).run ()
@@ -308,11 +294,13 @@ def _send(self, msg, *args):
308
294
except Exception :
309
295
pass
310
296
if message is not None and message [0 ] == "error" :
311
- warns , old_error = message [1 :]
312
- if warns is not None :
313
- error = ParallelSamplingError (str (old_error ), self .chain , warns )
297
+ old_error = message [1 ]
298
+ if old_error is not None :
299
+ error = ParallelSamplingError (
300
+ f"Chain { self .chain } failed with: { old_error } " , self .chain
301
+ )
314
302
else :
315
- error = RuntimeError ("Chain %s failed." % self . chain )
303
+ error = RuntimeError (f "Chain { self . chain } failed." )
316
304
raise error from old_error
317
305
raise
318
306
@@ -345,11 +333,13 @@ def recv_draw(processes, timeout=3600):
345
333
msg = ready [0 ].recv ()
346
334
347
335
if msg [0 ] == "error" :
348
- warns , old_error = msg [1 :]
349
- if warns is not None :
350
- error = ParallelSamplingError (str (old_error ), proc .chain , warns )
336
+ old_error = msg [1 ]
337
+ if old_error is not None :
338
+ error = ParallelSamplingError (
339
+ f"Chain { proc .chain } failed with: { old_error } " , proc .chain
340
+ )
351
341
else :
352
- error = RuntimeError ("Chain %s failed." % proc . chain )
342
+ error = RuntimeError (f "Chain { proc . chain } failed." )
353
343
raise error from old_error
354
344
elif msg [0 ] == "writing_done" :
355
345
proc ._readable = True
@@ -383,7 +373,7 @@ def terminate_all(processes, patience=2):
383
373
process .join ()
384
374
385
375
386
- Draw = namedtuple ("Draw" , ["chain" , "is_last" , "draw_idx" , "tuning" , "stats" , "point" , "warnings" ])
376
+ Draw = namedtuple ("Draw" , ["chain" , "is_last" , "draw_idx" , "tuning" , "stats" , "point" ])
387
377
388
378
389
379
class ParallelSampler :
@@ -466,7 +456,7 @@ def __iter__(self):
466
456
467
457
while self ._active :
468
458
draw = ProcessAdapter .recv_draw (self ._active )
469
- proc , is_last , draw , tuning , stats , warns = draw
459
+ proc , is_last , draw , tuning , stats = draw
470
460
self ._total_draws += 1
471
461
if not tuning and stats and stats [0 ].get ("diverging" ):
472
462
self ._divergences += 1
@@ -491,7 +481,7 @@ def __iter__(self):
491
481
if not is_last :
492
482
proc .write_next ()
493
483
494
- yield Draw (proc .chain , is_last , draw , tuning , stats , point , warns )
484
+ yield Draw (proc .chain , is_last , draw , tuning , stats , point )
495
485
496
486
def __enter__ (self ):
497
487
self ._in_context = True
0 commit comments