Skip to content

Commit 7fb3623

Browse files
Collect sampler warnings only through stats (#6192)
* Collect sampler warnings only through stats. Default to `pm.sample(keep_warning_stat=False)` to maintain compatibility with saving InferenceData. Closes #6191
1 parent 413d8c2 commit 7fb3623

16 files changed

+340
-138
lines changed

pymc/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __set_compiler_flags():
7272
from pymc.stats import *
7373
from pymc.step_methods import *
7474
from pymc.tuning import *
75+
from pymc.util import drop_warning_stat
7576
from pymc.variational import *
7677
from pymc.vartypes import *
7778

pymc/backends/base.py

-7
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ def __init__(self, name, model=None, vars=None, test_point=None):
7878
self.chain = None
7979
self._is_base_setup = False
8080
self.sampler_vars = None
81-
self._warnings = []
82-
83-
def _add_warnings(self, warnings):
84-
self._warnings.extend(warnings)
8581

8682
# Sampling methods
8783

@@ -288,9 +284,6 @@ def __init__(self, straces):
288284
self._straces[strace.chain] = strace
289285

290286
self._report = SamplerReport()
291-
for strace in straces:
292-
if hasattr(strace, "_warnings"):
293-
self._report._add_warnings(strace._warnings, strace.chain)
294287

295288
def __repr__(self):
296289
template = "<{}: {} chains, {} iterations, {} variables>"

pymc/parallel_sampling.py

+21-31
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,9 @@
4040

4141

4242
class ParallelSamplingError(Exception):
43-
def __init__(self, message, chain, warnings=None):
43+
def __init__(self, message, chain):
4444
super().__init__(message)
45-
if warnings is None:
46-
warnings = []
4745
self._chain = chain
48-
self._warnings = warnings
4946

5047

5148
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -74,8 +71,8 @@ def rebuild_exc(exc, tb):
7471

7572

7673
# Messages
77-
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
78-
# ('error', warnings, *exception_info)
74+
# ('writing_done', is_last, sample_idx, tuning, stats)
75+
# ('error', *exception_info)
7976

8077
# ('abort', reason)
8178
# ('write_next',)
@@ -133,7 +130,7 @@ def run(self):
133130
e = ExceptionWithTraceback(e, e.__traceback__)
134131
# Send is not blocking so we have to force a wait for the abort
135132
# message
136-
self._msg_pipe.send(("error", None, e))
133+
self._msg_pipe.send(("error", e))
137134
self._wait_for_abortion()
138135
finally:
139136
self._msg_pipe.close()
@@ -181,9 +178,8 @@ def _start_loop(self):
181178
try:
182179
point, stats = self._compute_point()
183180
except SamplingError as e:
184-
warns = self._collect_warnings()
185181
e = ExceptionWithTraceback(e, e.__traceback__)
186-
self._msg_pipe.send(("error", warns, e))
182+
self._msg_pipe.send(("error", e))
187183
else:
188184
return
189185

@@ -193,11 +189,7 @@ def _start_loop(self):
193189
elif msg[0] == "write_next":
194190
self._write_point(point)
195191
is_last = draw + 1 == self._draws + self._tune
196-
if is_last:
197-
warns = self._collect_warnings()
198-
else:
199-
warns = None
200-
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats, warns))
192+
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
201193
draw += 1
202194
else:
203195
raise ValueError("Unknown message " + msg[0])
@@ -210,12 +202,6 @@ def _compute_point(self):
210202
stats = None
211203
return point, stats
212204

213-
def _collect_warnings(self):
214-
if hasattr(self._step_method, "warnings"):
215-
return self._step_method.warnings()
216-
else:
217-
return []
218-
219205

220206
def _run_process(*args):
221207
_Process(*args).run()
@@ -308,11 +294,13 @@ def _send(self, msg, *args):
308294
except Exception:
309295
pass
310296
if message is not None and message[0] == "error":
311-
warns, old_error = message[1:]
312-
if warns is not None:
313-
error = ParallelSamplingError(str(old_error), self.chain, warns)
297+
old_error = message[1]
298+
if old_error is not None:
299+
error = ParallelSamplingError(
300+
f"Chain {self.chain} failed with: {old_error}", self.chain
301+
)
314302
else:
315-
error = RuntimeError("Chain %s failed." % self.chain)
303+
error = RuntimeError(f"Chain {self.chain} failed.")
316304
raise error from old_error
317305
raise
318306

@@ -345,11 +333,13 @@ def recv_draw(processes, timeout=3600):
345333
msg = ready[0].recv()
346334

347335
if msg[0] == "error":
348-
warns, old_error = msg[1:]
349-
if warns is not None:
350-
error = ParallelSamplingError(str(old_error), proc.chain, warns)
336+
old_error = msg[1]
337+
if old_error is not None:
338+
error = ParallelSamplingError(
339+
f"Chain {proc.chain} failed with: {old_error}", proc.chain
340+
)
351341
else:
352-
error = RuntimeError("Chain %s failed." % proc.chain)
342+
error = RuntimeError(f"Chain {proc.chain} failed.")
353343
raise error from old_error
354344
elif msg[0] == "writing_done":
355345
proc._readable = True
@@ -383,7 +373,7 @@ def terminate_all(processes, patience=2):
383373
process.join()
384374

385375

386-
Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"])
376+
Draw = namedtuple("Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point"])
387377

388378

389379
class ParallelSampler:
@@ -466,7 +456,7 @@ def __iter__(self):
466456

467457
while self._active:
468458
draw = ProcessAdapter.recv_draw(self._active)
469-
proc, is_last, draw, tuning, stats, warns = draw
459+
proc, is_last, draw, tuning, stats = draw
470460
self._total_draws += 1
471461
if not tuning and stats and stats[0].get("diverging"):
472462
self._divergences += 1
@@ -491,7 +481,7 @@ def __iter__(self):
491481
if not is_last:
492482
proc.write_next()
493483

494-
yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)
484+
yield Draw(proc.chain, is_last, draw, tuning, stats, point)
495485

496486
def __enter__(self):
497487
self._in_context = True

pymc/sampling.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,13 @@
7070
)
7171
from pymc.model import Model, modelcontext
7272
from pymc.parallel_sampling import Draw, _cpu_count
73-
from pymc.stats.convergence import run_convergence_checks
73+
from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks
7474
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
7575
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
7676
from pymc.step_methods.hmc import quadpotential
7777
from pymc.util import (
7878
dataset_to_point_list,
79+
drop_warning_stat,
7980
get_default_varnames,
8081
get_untransformed_name,
8182
is_transformed_name,
@@ -323,6 +324,7 @@ def sample(
323324
jitter_max_retries: int = 10,
324325
*,
325326
return_inferencedata: bool = True,
327+
keep_warning_stat: bool = False,
326328
idata_kwargs: dict = None,
327329
mp_ctx=None,
328330
**kwargs,
@@ -393,6 +395,13 @@ def sample(
393395
`MultiTrace` (False). Defaults to `True`.
394396
idata_kwargs : dict, optional
395397
Keyword arguments for :func:`pymc.to_inference_data`
398+
keep_warning_stat : bool
399+
If ``True`` the "warning" stat emitted by, for example, HMC samplers will be kept
400+
in the returned ``idata.sample_stat`` group.
401+
This leads to the ``idata`` not supporting ``.to_netcdf()`` or ``.to_zarr()`` and
402+
should only be set to ``True`` if you intend to use the "warning" objects right away.
403+
Defaults to ``False`` such that ``pm.drop_warning_stat`` is applied automatically,
404+
making the ``InferenceData`` compatible with saving.
396405
mp_ctx : multiprocessing.context.BaseContent
397406
A multiprocessing context for parallel sampling.
398407
See multiprocessing documentation for details.
@@ -699,6 +708,10 @@ def sample(
699708
mtrace.report._add_warnings(convergence_warnings)
700709

701710
if return_inferencedata:
711+
# By default we drop the "warning" stat which contains `SamplerWarning`
712+
# objects that can not be stored with `.to_netcdf()`.
713+
if not keep_warning_stat:
714+
return drop_warning_stat(idata)
702715
return idata
703716
return mtrace
704717

@@ -1048,32 +1061,26 @@ def _iter_sample(
10481061
if step.generates_stats:
10491062
point, stats = step.step(point)
10501063
strace.record(point, stats)
1064+
log_warning_stats(stats)
10511065
diverging = i > tune and stats and stats[0].get("diverging")
10521066
else:
10531067
point = step.step(point)
10541068
strace.record(point)
10551069
if callback is not None:
1056-
warns = getattr(step, "warnings", None)
10571070
callback(
10581071
trace=strace,
1059-
draw=Draw(chain, i == draws, i, i < tune, stats, point, warns),
1072+
draw=Draw(chain, i == draws, i, i < tune, stats, point),
10601073
)
10611074

10621075
yield strace, diverging
10631076
except KeyboardInterrupt:
10641077
strace.close()
1065-
if hasattr(step, "warnings"):
1066-
warns = step.warnings()
1067-
strace._add_warnings(warns)
10681078
raise
10691079
except BaseException:
10701080
strace.close()
10711081
raise
10721082
else:
10731083
strace.close()
1074-
if hasattr(step, "warnings"):
1075-
warns = step.warnings()
1076-
strace._add_warnings(warns)
10771084

10781085

10791086
class PopulationStepper:
@@ -1356,6 +1363,7 @@ def _iter_population(
13561363
if steppers[c].generates_stats:
13571364
points[c], stats = updates[c]
13581365
strace.record(points[c], stats)
1366+
log_warning_stats(stats)
13591367
else:
13601368
points[c] = updates[c]
13611369
strace.record(points[c])
@@ -1513,21 +1521,16 @@ def _mp_sample(
15131521
with sampler:
15141522
for draw in sampler:
15151523
strace = traces[draw.chain]
1516-
if draw.stats is not None:
1517-
strace.record(draw.point, draw.stats)
1518-
else:
1519-
strace.record(draw.point)
1524+
strace.record(draw.point, draw.stats)
1525+
log_warning_stats(draw.stats)
15201526
if draw.is_last:
15211527
strace.close()
1522-
if draw.warnings is not None:
1523-
strace._add_warnings(draw.warnings)
15241528

15251529
if callback is not None:
15261530
callback(trace=trace, draw=draw)
15271531

15281532
except ps.ParallelSamplingError as error:
15291533
strace = traces[error._chain]
1530-
strace._add_warnings(error._warnings)
15311534
for strace in traces:
15321535
strace.close()
15331536

@@ -1546,6 +1549,22 @@ def _mp_sample(
15461549
strace.close()
15471550

15481551

1552+
def log_warning_stats(stats: Sequence[Dict[str, Any]]):
1553+
"""Logs 'warning' stats if present."""
1554+
if stats is None:
1555+
return
1556+
1557+
for sts in stats:
1558+
warn = sts.get("warning", None)
1559+
if warn is None:
1560+
continue
1561+
if isinstance(warn, SamplerWarning):
1562+
log_warning(warn)
1563+
else:
1564+
_log.warning(warn)
1565+
return
1566+
1567+
15491568
def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]:
15501569
"""
15511570
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.

pymc/stats/convergence.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar
6868
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
6969
return [warn]
7070

71-
warnings = []
71+
warnings: List[SamplerWarning] = []
7272
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
7373
varnames = []
7474
for rv in model.free_RVs:
@@ -104,11 +104,60 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWar
104104
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
105105
warnings.append(warn)
106106

107+
warnings += warn_divergences(idata)
108+
warnings += warn_treedepth(idata)
109+
110+
return warnings
111+
112+
113+
def warn_divergences(idata: arviz.InferenceData) -> List[SamplerWarning]:
114+
"""Checks sampler stats and creates a list of warnings about divergences."""
115+
sampler_stats = idata.get("sample_stats", None)
116+
if sampler_stats is None:
117+
return []
118+
119+
diverging = sampler_stats.get("diverging", None)
120+
if diverging is None:
121+
return []
122+
123+
# Warn about divergences
124+
n_div = int(diverging.sum())
125+
if n_div == 0:
126+
return []
127+
warning = SamplerWarning(
128+
WarningType.DIVERGENCES,
129+
f"There were {n_div} divergences after tuning. Increase `target_accept` or reparameterize.",
130+
"error",
131+
)
132+
return [warning]
133+
134+
135+
def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]:
136+
"""Checks sampler stats and creates a list of warnings about tree depth."""
137+
sampler_stats = idata.get("sample_stats", None)
138+
if sampler_stats is None:
139+
return []
140+
141+
treedepth = sampler_stats.get("tree_depth", None)
142+
if treedepth is None:
143+
return []
144+
145+
warnings = []
146+
for c in treedepth.chain:
147+
if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05:
148+
warnings.append(
149+
SamplerWarning(
150+
WarningType.TREEDEPTH,
151+
f"Chain {c} reached the maximum tree depth."
152+
" Increase `max_treedepth`, increase `target_accept` or reparameterize.",
153+
"warn",
154+
)
155+
)
107156
return warnings
108157

109158

110159
def log_warning(warn: SamplerWarning):
111-
level = _LEVELS[warn.level]
160+
level = _LEVELS.get(warn.level, logging.WARNING)
112161
logger.log(level, warn.message)
113162

114163

pymc/step_methods/compound.py

-7
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ def step(self, point):
5959
point = method.step(point)
6060
return point
6161

62-
def warnings(self):
63-
warns = []
64-
for method in self.methods:
65-
if hasattr(method, "warnings"):
66-
warns.extend(method.warnings())
67-
return warns
68-
6962
def stop_tuning(self):
7063
for method in self.methods:
7164
method.stop_tuning()

0 commit comments

Comments
 (0)