13
13
# limitations under the License.
14
14
15
15
import dataclasses
16
- import enum
17
16
import logging
18
17
19
- from typing import Any , Optional
18
+ from typing import Optional
20
19
21
20
import arviz
22
21
23
- from pymc .util import get_untransformed_name , is_transformed_name
22
+ from pymc .stats .convergence import (
23
+ _LEVELS ,
24
+ SamplerWarning ,
25
+ log_warnings ,
26
+ run_convergence_checks ,
27
+ )
24
28
25
29
logger = logging .getLogger ("pymc" )
26
30
27
31
28
- @enum .unique
29
- class WarningType (enum .Enum ):
30
- # For HMC and NUTS
31
- DIVERGENCE = 1
32
- TUNING_DIVERGENCE = 2
33
- DIVERGENCES = 3
34
- TREEDEPTH = 4
35
- # Problematic sampler parameters
36
- BAD_PARAMS = 5
37
- # Indications that chains did not converge, eg Rhat
38
- CONVERGENCE = 6
39
- BAD_ACCEPTANCE = 7
40
- BAD_ENERGY = 8
41
-
42
-
43
- @dataclasses .dataclass
44
- class SamplerWarning :
45
- kind : WarningType
46
- message : str
47
- level : str
48
- step : Optional [int ] = None
49
- exec_info : Optional [Any ] = None
50
- extra : Optional [Any ] = None
51
- divergence_point_source : Optional [dict ] = None
52
- divergence_point_dest : Optional [dict ] = None
53
- divergence_info : Optional [Any ] = None
54
-
55
-
56
- _LEVELS = {
57
- "info" : logging .INFO ,
58
- "error" : logging .ERROR ,
59
- "warn" : logging .WARN ,
60
- "debug" : logging .DEBUG ,
61
- "critical" : logging .CRITICAL ,
62
- }
63
-
64
-
65
32
class SamplerReport :
66
33
"""Bundle warnings, convergence stats and metadata of a sampling run."""
67
34
68
35
def __init__ (self ):
69
- self ._chain_warnings = {}
70
- self ._global_warnings = []
71
- self ._ess = None
72
- self ._rhat = None
36
+ self ._chain_warnings : Dict [int , List [SamplerWarning ]] = {}
37
+ self ._global_warnings : List [SamplerWarning ] = []
73
38
self ._n_tune = None
74
39
self ._n_draws = None
75
40
self ._t_sampling = None
@@ -109,65 +74,7 @@ def raise_ok(self, level="error"):
109
74
raise ValueError ("Serious convergence issues during sampling." )
110
75
111
76
def _run_convergence_checks (self , idata : arviz .InferenceData , model ):
112
- if not hasattr (idata , "posterior" ):
113
- msg = "No posterior samples. Unable to run convergence checks"
114
- warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" , None , None , None )
115
- self ._add_warnings ([warn ])
116
- return
117
-
118
- if idata ["posterior" ].sizes ["chain" ] == 1 :
119
- msg = (
120
- "Only one chain was sampled, this makes it impossible to "
121
- "run some convergence checks"
122
- )
123
- warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
124
- self ._add_warnings ([warn ])
125
- return
126
-
127
- elif idata ["posterior" ].sizes ["chain" ] < 4 :
128
- msg = (
129
- "We recommend running at least 4 chains for robust computation of "
130
- "convergence diagnostics"
131
- )
132
- warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
133
- self ._add_warnings ([warn ])
134
- return
135
-
136
- valid_name = [rv .name for rv in model .free_RVs + model .deterministics ]
137
- varnames = []
138
- for rv in model .free_RVs :
139
- rv_name = rv .name
140
- if is_transformed_name (rv_name ):
141
- rv_name2 = get_untransformed_name (rv_name )
142
- rv_name = rv_name2 if rv_name2 in valid_name else rv_name
143
- if rv_name in idata ["posterior" ]:
144
- varnames .append (rv_name )
145
-
146
- self ._ess = ess = arviz .ess (idata , var_names = varnames )
147
- self ._rhat = rhat = arviz .rhat (idata , var_names = varnames )
148
-
149
- warnings = []
150
- rhat_max = max (val .max () for val in rhat .values ())
151
- if rhat_max > 1.01 :
152
- msg = (
153
- "The rhat statistic is larger than 1.01 for some "
154
- "parameters. This indicates problems during sampling. "
155
- "See https://arxiv.org/abs/1903.08008 for details"
156
- )
157
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "info" , extra = rhat )
158
- warnings .append (warn )
159
-
160
- eff_min = min (val .min () for val in ess .values ())
161
- eff_per_chain = eff_min / idata ["posterior" ].sizes ["chain" ]
162
- if eff_per_chain < 100 :
163
- msg = (
164
- "The effective sample size per chain is smaller than 100 for some parameters. "
165
- " A higher number is needed for reliable rhat and ess computation. "
166
- "See https://arxiv.org/abs/1903.08008 for details"
167
- )
168
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "error" , extra = ess )
169
- warnings .append (warn )
170
-
77
+ warnings = run_convergence_checks (idata , model )
171
78
self ._add_warnings (warnings )
172
79
173
80
def _add_warnings (self , warnings , chain = None ):
@@ -178,15 +85,9 @@ def _add_warnings(self, warnings, chain=None):
178
85
warn_list .extend (warnings )
179
86
180
87
def _log_summary (self ):
181
- def log_warning (warn ):
182
- level = _LEVELS [warn .level ]
183
- logger .log (level , warn .message )
184
-
185
88
for chain , warns in self ._chain_warnings .items ():
186
- for warn in warns :
187
- log_warning (warn )
188
- for warn in self ._global_warnings :
189
- log_warning (warn )
89
+ log_warnings (warns )
90
+ log_warnings (self ._global_warnings )
190
91
191
92
def _slice (self , start , stop , step ):
192
93
report = SamplerReport ()
0 commit comments