Skip to content

Commit d5fcbb0

Browse files
committed
Merge branch 'main' into adding-blackjax-support
2 parents c2a5ea7 + 44c5495 commit d5fcbb0

28 files changed

+1427
-2599
lines changed

.github/workflows/nightly.yml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: nightly
2+
3+
on:
4+
schedule:
5+
- cron: "0 0 * * *"
6+
7+
jobs:
8+
build-and-publish-nightly:
9+
name: Build source distribution
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v2
13+
with:
14+
fetch-depth: 0
15+
- uses: actions/setup-python@v2
16+
with:
17+
python-version: 3.9
18+
- name: Install dependencies
19+
run: |
20+
python -m pip install -U pip
21+
python -m pip install build
22+
- name: Build the sdist
23+
run: python -m build --sdist .
24+
env:
25+
BUILD_PYMC_NIGHTLY: true
26+
- name: Publish to PyPI
27+
uses: pypa/[email protected]
28+
with:
29+
user: __token__
30+
password: ${{ secrets.PYPI_TOKEN_PYMC_NIGHTLY }}
31+
test-install-job:
32+
needs: build-and-publish-nightly
33+
runs-on: ubuntu-latest
34+
steps:
35+
- name: Set up Python
36+
uses: actions/setup-python@v2
37+
with:
38+
python-version: 3.9
39+
- name: Give PyPI a chance to update the index
40+
run: sleep 240
41+
- name: Install from PyPI
42+
run: |
43+
pip install pymc-nightly==$(grep 'version' pymc/__init__.py | awk '{print $3}' | tr -d '"').dev$(date +"%Y%m%d")

.github/workflows/pytest.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ jobs:
6363
--ignore=pymc/tests/test_distributions_random.py
6464
--ignore=pymc/tests/test_idata_conversion.py
6565
--ignore=pymc/tests/test_smc.py
66-
--ignore=pymc/tests/test_bart.py
6766
--ignore=pymc/tests/test_missing.py
6867
6968
- |
@@ -77,7 +76,7 @@ jobs:
7776
pymc/tests/test_updates.py
7877
pymc/tests/test_transforms.py
7978
pymc/tests/test_smc.py
80-
pymc/tests/test_bart.py
79+
pymc/tests/test_mixture.py
8180
8281
- |
8382
pymc/tests/test_parallel_sampling.py

RELEASE-NOTES.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ Instead update the vNext section until 4.0.0 is out.
77
⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠⚠
88
-->
99

10-
## PyMC vNext (4.0.0b1 → 4.0.0b2 → 4.0.0b3 → 4.0.0)
10+
## PyMC vNext (4.0.0b1 → 4.0.0b2 → 4.0.0b3 → 4.0.0b4 → 4.0.0)
1111
⚠ The changes below are the delta between the upcoming releases `v3.11.5` →...→ `v4.0.0`.
1212

13-
### No-yet working features
13+
### Not-yet working features
1414
We plan to get these working again, but at this point their inner workings have not been refactored.
1515
- Timeseries distributions (see [#4642](https://github.com/pymc-devs/pymc/issues/4642))
16-
- Mixture distributions (see [#4781](https://github.com/pymc-devs/pymc/issues/4781))
17-
- Cholesky distributions (see WIP PR [#4784](https://github.com/pymc-devs/pymc/pull/4784))
18-
- Variational inference submodule (see WIP PR [#4582](https://github.com/pymc-devs/pymc/pull/4582))
16+
- Nested Mixture distributions (see [#5533](https://github.com/pymc-devs/pymc/issues/5533))
1917
- Elliptical slice sampling (see [#5137](https://github.com/pymc-devs/pymc/issues/5137))
2018
- `BaseStochasticGradient` (see [#5138](https://github.com/pymc-devs/pymc/issues/5138))
2119
- `pm.sample_posterior_predictive_w` (see [#4807](https://github.com/pymc-devs/pymc/issues/4807))
@@ -74,6 +72,7 @@ All of the above apply to:
7472
- In the gp.utils file, the `kmeans_inducing_points` function now passes through `kmeans_kwargs` to scipy's k-means function.
7573
- The function `replace_with_values` function has been added to `gp.utils`.
7674
- `MarginalSparse` has been renamed `MarginalApprox`.
75+
- Removed `MixtureSameFamily`. `Mixture` is now capable of handling batched multivariate components (see [#5438](https://github.com/pymc-devs/pymc/pull/5438)).
7776
- ...
7877

7978
### Expected breaks
@@ -129,6 +128,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
129128
- `softmax` and `log_softmax` functions added to `math` module (see [#5279](https://github.com/pymc-devs/pymc/pull/5279)).
130129
- ...
131130

131+
132132
## Documentation
133133
- Switched to the [pydata-sphinx-theme](https://pydata-sphinx-theme.readthedocs.io/en/latest/)
134134
- Updated our documentation tooling to use [MyST](https://myst-parser.readthedocs.io/en/latest/), [MyST-NB](https://myst-nb.readthedocs.io/en/latest/), sphinx-design, notfound.extension,

docs/source/api.rst

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ API Reference
1515
api/smc
1616
api/backends
1717
api/data
18-
api/bart
1918
api/ode
2019
api/tuning
2120
api/math

docs/source/api/bart.rst

-12
This file was deleted.

docs/source/api/distributions/mixture.rst

-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ Mixture
88

99
Mixture
1010
NormalMixture
11-
MixtureSameFamily

pymc/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=wildcard-import
16-
__version__ = "4.0.0b2"
16+
__version__ = "4.0.0b3"
1717

1818
import logging
1919
import multiprocessing as mp
@@ -52,7 +52,6 @@ def __set_compiler_flags():
5252
from pymc import gp, ode, sampling
5353
from pymc.aesaraf import *
5454
from pymc.backends import *
55-
from pymc.bart import *
5655
from pymc.blocking import *
5756
from pymc.data import *
5857
from pymc.distributions import *

pymc/backends/report.py

+18-28
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
124124
self._add_warnings([warn])
125125
return
126126

127+
elif idata.posterior.sizes["chain"] < 4:
128+
msg = (
129+
"We recommend running at least 4 chains for robust computation of "
130+
"convergence diagnostics"
131+
)
132+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
133+
self._add_warnings([warn])
134+
return
135+
127136
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
128137
varnames = []
129138
for rv in model.free_RVs:
@@ -139,44 +148,25 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
139148

140149
warnings = []
141150
rhat_max = max(val.max() for val in rhat.values())
142-
if rhat_max > 1.4:
151+
if rhat_max > 1.01:
143152
msg = (
144-
"The rhat statistic is larger than 1.4 for some "
145-
"parameters. The sampler did not converge."
146-
)
147-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=rhat)
148-
warnings.append(warn)
149-
elif rhat_max > 1.2:
150-
msg = "The rhat statistic is larger than 1.2 for some " "parameters."
151-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "warn", extra=rhat)
152-
warnings.append(warn)
153-
elif rhat_max > 1.05:
154-
msg = (
155-
"The rhat statistic is larger than 1.05 for some "
156-
"parameters. This indicates slight problems during "
157-
"sampling."
153+
"The rhat statistic is larger than 1.01 for some "
154+
"parameters. This indicates problems during sampling. "
155+
"See https://arxiv.org/abs/1903.08008 for details"
158156
)
159157
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
160158
warnings.append(warn)
161159

162160
eff_min = min(val.min() for val in ess.values())
163-
sizes = idata.posterior.sizes
164-
n_samples = sizes["chain"] * sizes["draw"]
165-
if eff_min < 200 and n_samples >= 500:
161+
eff_per_chain = eff_min / idata.posterior.sizes["chain"]
162+
if eff_per_chain < 100:
166163
msg = (
167-
"The estimated number of effective samples is smaller than "
168-
"200 for some parameters."
164+
"The effective sample size per chain is smaller than 100 for some parameters. "
165+
" A higher number is needed for reliable rhat and ess computation. "
166+
"See https://arxiv.org/abs/1903.08008 for details"
169167
)
170168
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
171169
warnings.append(warn)
172-
elif eff_min / n_samples < 0.1:
173-
msg = "The number of effective samples is smaller than " "10% for some parameters."
174-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "warn", extra=ess)
175-
warnings.append(warn)
176-
elif eff_min / n_samples < 0.25:
177-
msg = "The number of effective samples is smaller than " "25% for some parameters."
178-
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=ess)
179-
warnings.append(warn)
180170

181171
self._add_warnings(warnings)
182172

pymc/bart/__init__.py

-20
This file was deleted.

pymc/bart/bart.py

-155
This file was deleted.

0 commit comments

Comments
 (0)