Skip to content

Commit a574ed5

Browse files
michaelosthege authored and ricardoV94 committed
Extract convergence checks function from SamplerReport
1 parent 91dbfd2 commit a574ed5

File tree

6 files changed

+136
-116
lines changed

6 files changed

+136
-116
lines changed

pymc/backends/report.py

+12-111
Original file line numberDiff line numberDiff line change
@@ -13,63 +13,28 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16-
import enum
1716
import logging
1817

19-
from typing import Any, Optional
18+
from typing import Optional
2019

2120
import arviz
2221

23-
from pymc.util import get_untransformed_name, is_transformed_name
22+
from pymc.stats.convergence import (
23+
_LEVELS,
24+
SamplerWarning,
25+
log_warnings,
26+
run_convergence_checks,
27+
)
2428

2529
logger = logging.getLogger("pymc")
2630

2731

28-
@enum.unique
29-
class WarningType(enum.Enum):
30-
# For HMC and NUTS
31-
DIVERGENCE = 1
32-
TUNING_DIVERGENCE = 2
33-
DIVERGENCES = 3
34-
TREEDEPTH = 4
35-
# Problematic sampler parameters
36-
BAD_PARAMS = 5
37-
# Indications that chains did not converge, eg Rhat
38-
CONVERGENCE = 6
39-
BAD_ACCEPTANCE = 7
40-
BAD_ENERGY = 8
41-
42-
43-
@dataclasses.dataclass
44-
class SamplerWarning:
45-
kind: WarningType
46-
message: str
47-
level: str
48-
step: Optional[int] = None
49-
exec_info: Optional[Any] = None
50-
extra: Optional[Any] = None
51-
divergence_point_source: Optional[dict] = None
52-
divergence_point_dest: Optional[dict] = None
53-
divergence_info: Optional[Any] = None
54-
55-
56-
_LEVELS = {
57-
"info": logging.INFO,
58-
"error": logging.ERROR,
59-
"warn": logging.WARN,
60-
"debug": logging.DEBUG,
61-
"critical": logging.CRITICAL,
62-
}
63-
64-
6532
class SamplerReport:
6633
"""Bundle warnings, convergence stats and metadata of a sampling run."""
6734

6835
def __init__(self):
69-
self._chain_warnings = {}
70-
self._global_warnings = []
71-
self._ess = None
72-
self._rhat = None
36+
self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
37+
self._global_warnings: List[SamplerWarning] = []
7338
self._n_tune = None
7439
self._n_draws = None
7540
self._t_sampling = None
@@ -109,65 +74,7 @@ def raise_ok(self, level="error"):
10974
raise ValueError("Serious convergence issues during sampling.")
11075

11176
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
112-
if not hasattr(idata, "posterior"):
113-
msg = "No posterior samples. Unable to run convergence checks"
114-
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
115-
self._add_warnings([warn])
116-
return
117-
118-
if idata["posterior"].sizes["chain"] == 1:
119-
msg = (
120-
"Only one chain was sampled, this makes it impossible to "
121-
"run some convergence checks"
122-
)
123-
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
124-
self._add_warnings([warn])
125-
return
126-
127-
elif idata["posterior"].sizes["chain"] < 4:
128-
msg = (
129-
"We recommend running at least 4 chains for robust computation of "
130-
"convergence diagnostics"
131-
)
132-
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
133-
self._add_warnings([warn])
134-
return
135-
136-
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
137-
varnames = []
138-
for rv in model.free_RVs:
139-
rv_name = rv.name
140-
if is_transformed_name(rv_name):
141-
rv_name2 = get_untransformed_name(rv_name)
142-
rv_name = rv_name2 if rv_name2 in valid_name else rv_name
143-
if rv_name in idata["posterior"]:
144-
varnames.append(rv_name)
145-
146-
self._ess = ess = arviz.ess(idata, var_names=varnames)
147-
self._rhat = rhat = arviz.rhat(idata, var_names=varnames)
148-
149-
warnings = []
150-
rhat_max = max(val.max() for val in rhat.values())
151-
if rhat_max > 1.01:
152-
msg = (
153-
"The rhat statistic is larger than 1.01 for some "
154-
"parameters. This indicates problems during sampling. "
155-
"See https://arxiv.org/abs/1903.08008 for details"
156-
)
157-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
158-
warnings.append(warn)
159-
160-
eff_min = min(val.min() for val in ess.values())
161-
eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
162-
if eff_per_chain < 100:
163-
msg = (
164-
"The effective sample size per chain is smaller than 100 for some parameters. "
165-
" A higher number is needed for reliable rhat and ess computation. "
166-
"See https://arxiv.org/abs/1903.08008 for details"
167-
)
168-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
169-
warnings.append(warn)
170-
77+
warnings = run_convergence_checks(idata, model)
17178
self._add_warnings(warnings)
17279

17380
def _add_warnings(self, warnings, chain=None):
@@ -178,15 +85,9 @@ def _add_warnings(self, warnings, chain=None):
17885
warn_list.extend(warnings)
17986

18087
def _log_summary(self):
181-
def log_warning(warn):
182-
level = _LEVELS[warn.level]
183-
logger.log(level, warn.message)
184-
18588
for chain, warns in self._chain_warnings.items():
186-
for warn in warns:
187-
log_warning(warn)
188-
for warn in self._global_warnings:
189-
log_warning(warn)
89+
log_warnings(warns)
90+
log_warnings(self._global_warnings)
19091

19192
def _slice(self, start, stop, step):
19293
report = SamplerReport()

pymc/sampling.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
)
7171
from pymc.model import Model, modelcontext
7272
from pymc.parallel_sampling import Draw, _cpu_count
73+
from pymc.stats.convergence import run_convergence_checks
7374
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
7475
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
7576
from pymc.step_methods.hmc import quadpotential
@@ -677,7 +678,7 @@ def sample(
677678
_log.info(
678679
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
679680
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
680-
f"took {mtrace.report.t_sampling:.0f} seconds."
681+
f"took {t_sampling:.0f} seconds."
681682
)
682683
mtrace.report._log_summary()
683684

@@ -695,7 +696,8 @@ def sample(
695696
stacklevel=2,
696697
)
697698
else:
698-
mtrace.report._run_convergence_checks(idata, model)
699+
convergence_warnings = run_convergence_checks(idata, model)
700+
mtrace.report._add_warnings(convergence_warnings)
699701

700702
if return_inferencedata:
701703
return idata

pymc/stats/convergence.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import dataclasses
2+
import enum
3+
import logging
4+
5+
from typing import Any, List, Optional, Sequence
6+
7+
import arviz
8+
9+
from pymc.util import get_untransformed_name, is_transformed_name
10+
11+
_LEVELS = {
12+
"info": logging.INFO,
13+
"error": logging.ERROR,
14+
"warn": logging.WARN,
15+
"debug": logging.DEBUG,
16+
"critical": logging.CRITICAL,
17+
}
18+
19+
logger = logging.getLogger("pymc")
20+
21+
22+
@enum.unique
class WarningType(enum.Enum):
    """Category tag attached to a sampler warning.

    ``enum.unique`` guarantees that no two members share a numeric value.
    """

    # For HMC and NUTS
    DIVERGENCE = 1
    TUNING_DIVERGENCE = 2
    DIVERGENCES = 3
    TREEDEPTH = 4
    # Problematic sampler parameters
    BAD_PARAMS = 5
    # Indications that chains did not converge, eg Rhat
    CONVERGENCE = 6
    BAD_ACCEPTANCE = 7
    BAD_ENERGY = 8
35+
36+
37+
@dataclasses.dataclass
class SamplerWarning:
    """One warning entry, as created by the convergence checks in this module.

    Only ``kind``, ``message`` and ``level`` are required; every other field
    defaults to ``None``.
    """

    # Category of the warning.
    kind: WarningType
    # Text that is written to the log when the warning is reported.
    message: str
    # Key into ``_LEVELS`` ("info", "error", "warn", "debug" or "critical").
    level: str
    step: Optional[int] = None
    exec_info: Optional[Any] = None
    # Supporting data; the convergence checks store the rhat / ess results here.
    extra: Optional[Any] = None
    # NOTE(review): the three divergence_* fields are not populated anywhere in
    # this module; presumably filled in by the HMC/NUTS step methods - confirm.
    divergence_point_source: Optional[dict] = None
    divergence_point_dest: Optional[dict] = None
    divergence_info: Optional[Any] = None
48+
49+
50+
def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
    """Inspect posterior samples and report signs of poor convergence.

    Parameters
    ----------
    idata : arviz.InferenceData
        Sampling result. Only its ``posterior`` group is read.
    model
        Object providing ``free_RVs`` and ``deterministics``; the names of its
        free random variables select which posterior variables are checked.

    Returns
    -------
    list of SamplerWarning
        * a single ``BAD_PARAMS`` / "info" entry when there is no posterior
          group, only one chain, or fewer than four chains (in these cases no
          rhat / ess statistics are computed);
        * otherwise zero, one or two ``CONVERGENCE`` entries: one at level
          "info" when the largest rhat exceeds 1.01, and one at level "error"
          when the smallest effective sample size per chain is below 100.
    """
    if not hasattr(idata, "posterior"):
        msg = "No posterior samples. Unable to run convergence checks"
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    posterior = idata["posterior"]
    n_chains = posterior.sizes["chain"]

    if n_chains == 1:
        msg = (
            "Only one chain was sampled, this makes it impossible to "
            "run some convergence checks"
        )
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    if n_chains < 4:
        msg = (
            "We recommend running at least 4 chains for robust computation of "
            "convergence diagnostics"
        )
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    # A transformed free RV is checked under its untransformed name, but only
    # if the model really exposes a variable with that name.
    valid_names = {rv.name for rv in model.free_RVs + model.deterministics}
    varnames = []
    for rv in model.free_RVs:
        rv_name = rv.name
        if is_transformed_name(rv_name):
            untransformed = get_untransformed_name(rv_name)
            if untransformed in valid_names:
                rv_name = untransformed
        if rv_name in posterior:
            varnames.append(rv_name)

    ess = arviz.ess(idata, var_names=varnames)
    rhat = arviz.rhat(idata, var_names=varnames)

    warnings: List[SamplerWarning] = []

    # The emptiness guards keep max()/min() from raising ValueError when no
    # variable was selected for checking.
    rhat_maxima = [val.max() for val in rhat.values()]
    if rhat_maxima and max(rhat_maxima) > 1.01:
        msg = (
            "The rhat statistic is larger than 1.01 for some "
            "parameters. This indicates problems during sampling. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warnings.append(SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat))

    ess_minima = [val.min() for val in ess.values()]
    if ess_minima and min(ess_minima) / n_chains < 100:
        msg = (
            "The effective sample size per chain is smaller than 100 for some parameters. "
            " A higher number is needed for reliable rhat and ess computation. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warnings.append(SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess))

    return warnings
108+
109+
110+
def log_warning(warn: SamplerWarning) -> None:
    """Write ``warn.message`` to the module logger.

    The logging level is obtained by looking up ``warn.level`` in ``_LEVELS``.

    Raises
    ------
    KeyError
        If ``warn.level`` is not one of the keys of ``_LEVELS``.
    """
    logger.log(_LEVELS[warn.level], warn.message)
113+
114+
115+
def log_warnings(warnings: Sequence[SamplerWarning]) -> None:
    """Log each warning in ``warnings``, in order, via ``log_warning``."""
    for warn in warnings:
        log_warning(warn)

pymc/step_methods/hmc/base_hmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import numpy as np
2222

2323
from pymc.aesaraf import floatX
24-
from pymc.backends.report import SamplerWarning, WarningType
2524
from pymc.blocking import DictToArrayBijection, RaveledVars
2625
from pymc.exceptions import SamplingError
2726
from pymc.model import Point, modelcontext
27+
from pymc.stats.convergence import SamplerWarning, WarningType
2828
from pymc.step_methods import step_sizes
2929
from pymc.step_methods.arraystep import GradientSharedStep
3030
from pymc.step_methods.hmc import integration

pymc/step_methods/hmc/nuts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import numpy as np
1818

1919
from pymc.aesaraf import floatX
20-
from pymc.backends.report import SamplerWarning, WarningType
2120
from pymc.math import logbern
21+
from pymc.stats.convergence import SamplerWarning, WarningType
2222
from pymc.step_methods.arraystep import Competence
2323
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
2424
from pymc.step_methods.hmc.integration import IntegrationError

pymc/step_methods/step_sizes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from scipy import stats
1818

19-
from pymc.backends.report import SamplerWarning, WarningType
19+
from pymc.stats.convergence import SamplerWarning, WarningType
2020

2121

2222
class DualAverageAdaptation:

0 commit comments

Comments
 (0)