Skip to content

Commit 6833677

Browse files
authored
update reported warnings after sampling (#5516)
* update report * update per comments
1 parent 999160a commit 6833677

File tree

1 file changed

+18
-28
lines changed

1 file changed

+18
-28
lines changed

pymc/backends/report.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
124124
self._add_warnings([warn])
125125
return
126126

127+
elif idata.posterior.sizes["chain"] < 4:
128+
msg = (
129+
"We recommend running at least 4 chains for robust computation of "
130+
"convergence diagnostics"
131+
)
132+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
133+
self._add_warnings([warn])
134+
return
135+
127136
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
128137
varnames = []
129138
for rv in model.free_RVs:
@@ -139,44 +148,25 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
139148

140149
warnings = []
141150
rhat_max = max(val.max() for val in rhat.values())
142-
if rhat_max > 1.4:
151+
if rhat_max > 1.01:
143152
msg = (
144-
"The rhat statistic is larger than 1.4 for some "
145-
"parameters. The sampler did not converge."
146-
)
147-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=rhat)
148-
warnings.append(warn)
149-
elif rhat_max > 1.2:
150-
msg = "The rhat statistic is larger than 1.2 for some " "parameters."
151-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "warn", extra=rhat)
152-
warnings.append(warn)
153-
elif rhat_max > 1.05:
154-
msg = (
155-
"The rhat statistic is larger than 1.05 for some "
156-
"parameters. This indicates slight problems during "
157-
"sampling."
153+
"The rhat statistic is larger than 1.01 for some "
154+
"parameters. This indicates problems during sampling. "
155+
"See https://arxiv.org/abs/1903.08008 for details"
158156
)
159157
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
160158
warnings.append(warn)
161159

162160
eff_min = min(val.min() for val in ess.values())
163-
sizes = idata.posterior.sizes
164-
n_samples = sizes["chain"] * sizes["draw"]
165-
if eff_min < 200 and n_samples >= 500:
161+
eff_per_chain = eff_min / idata.posterior.sizes["chain"]
162+
if eff_per_chain < 100:
166163
msg = (
167-
"The estimated number of effective samples is smaller than "
168-
"200 for some parameters."
164+
"The effective sample size per chain is smaller than 100 for some parameters. "
165+
" A higher number is needed for reliable rhat and ess computation. "
166+
"See https://arxiv.org/abs/1903.08008 for details"
169167
)
170168
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
171169
warnings.append(warn)
172-
elif eff_min / n_samples < 0.1:
173-
msg = "The number of effective samples is smaller than " "10% for some parameters."
174-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "warn", extra=ess)
175-
warnings.append(warn)
176-
elif eff_min / n_samples < 0.25:
177-
msg = "The number of effective samples is smaller than " "25% for some parameters."
178-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=ess)
179-
warnings.append(warn)
180170

181171
self._add_warnings(warnings)
182172

0 commit comments

Comments
 (0)