@@ -124,6 +124,15 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
124
124
self ._add_warnings ([warn ])
125
125
return
126
126
127
+ elif idata .posterior .sizes ["chain" ] < 4 :
128
+ msg = (
129
+ "We recommend running at least 4 chains for robust computation of "
130
+ "convergence diagnostics"
131
+ )
132
+ warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
133
+ self ._add_warnings ([warn ])
134
+ return
135
+
127
136
valid_name = [rv .name for rv in model .free_RVs + model .deterministics ]
128
137
varnames = []
129
138
for rv in model .free_RVs :
@@ -139,44 +148,25 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
139
148
140
149
warnings = []
141
150
rhat_max = max (val .max () for val in rhat .values ())
142
- if rhat_max > 1.4 :
151
+ if rhat_max > 1.01 :
143
152
msg = (
144
- "The rhat statistic is larger than 1.4 for some "
145
- "parameters. The sampler did not converge."
146
- )
147
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "error" , extra = rhat )
148
- warnings .append (warn )
149
- elif rhat_max > 1.2 :
150
- msg = "The rhat statistic is larger than 1.2 for some " "parameters."
151
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "warn" , extra = rhat )
152
- warnings .append (warn )
153
- elif rhat_max > 1.05 :
154
- msg = (
155
- "The rhat statistic is larger than 1.05 for some "
156
- "parameters. This indicates slight problems during "
157
- "sampling."
153
+ "The rhat statistic is larger than 1.01 for some "
154
+ "parameters. This indicates problems during sampling. "
155
+ "See https://arxiv.org/abs/1903.08008 for details"
158
156
)
159
157
warn = SamplerWarning (WarningType .CONVERGENCE , msg , "info" , extra = rhat )
160
158
warnings .append (warn )
161
159
162
160
eff_min = min (val .min () for val in ess .values ())
163
- sizes = idata .posterior .sizes
164
- n_samples = sizes ["chain" ] * sizes ["draw" ]
165
- if eff_min < 200 and n_samples >= 500 :
161
+ eff_per_chain = eff_min / idata .posterior .sizes ["chain" ]
162
+ if eff_per_chain < 100 :
166
163
msg = (
167
- "The estimated number of effective samples is smaller than "
168
- "200 for some parameters."
164
+ "The effective sample size per chain is smaller than 100 for some parameters. "
165
+ " A higher number is needed for reliable rhat and ess computation. "
166
+ "See https://arxiv.org/abs/1903.08008 for details"
169
167
)
170
168
warn = SamplerWarning (WarningType .CONVERGENCE , msg , "error" , extra = ess )
171
169
warnings .append (warn )
172
- elif eff_min / n_samples < 0.1 :
173
- msg = "The number of effective samples is smaller than " "10% for some parameters."
174
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "warn" , extra = ess )
175
- warnings .append (warn )
176
- elif eff_min / n_samples < 0.25 :
177
- msg = "The number of effective samples is smaller than " "25% for some parameters."
178
- warn = SamplerWarning (WarningType .CONVERGENCE , msg , "info" , extra = ess )
179
- warnings .append (warn )
180
170
181
171
self ._add_warnings (warnings )
182
172
0 commit comments