From 65f9d432decfdcbbb3781fdff3e4d152b223cd94 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 19 Feb 2025 12:10:48 +0800 Subject: [PATCH 1/4] Use new `method` argument for `MvNormal` to defined `MvNormalSVD` --- .../statespace/filters/distributions.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d3b70c847..3d0ed44d6 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -62,29 +62,8 @@ class MvNormalSVD(MvNormal): A JAX MvNormal robust to low-rank covariance matrices """ - rv_op = MvNormalSVDRV() - - -try: - import jax.random - - from pytensor.link.jax.dispatch.random import jax_sample_fn - - @jax_sample_fn.register(MvNormalSVDRV) - def jax_sample_fn_mvnormal_svd(op, node): - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = jax.random.multivariate_normal( - sampling_key, *parameters, shape=size, dtype=dtype, method="svd" - ) - rng["jax_state"] = rng_key - return (rng, sample) - - return sample_fn - -except ImportError: - pass + # TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal + rv_op = MvNormalSVDRV(method="svd") class LinearGaussianStateSpaceRV(SymbolicRandomVariable): From c5194c0f7d67cdf81f285f89ff4b22e10f474bd4 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 1 Mar 2025 16:00:51 +0800 Subject: [PATCH 2/4] Can't use `if Variable` anymore --- pymc_extras/statespace/core/statespace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 2590dd53a..d5d131d80 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -707,7 +707,7 @@ def _insert_random_variables(self): with pymc_model: for param_name in self.param_names: param = getattr(pymc_model, param_name, None) - if param: + if param is not None: found_params.append(param.name) missing_params = list(set(self.param_names) - set(found_params)) @@ -746,7 +746,7 @@ def _insert_data_variables(self): with pymc_model: for data_name in data_names: data = getattr(pymc_model, data_name, None) - if data: + if data is not None: found_data.append(data.name) missing_data = list(set(data_names) - set(found_data)) From 907b083a7c9aea3e8fe138c9ef0eab3bc375c3a7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 1 Mar 2025 16:05:32 +0800 Subject: [PATCH 3/4] Bump PyMC version pin --- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index a1aa85c7e..450b46e30 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -3,7 +3,7 @@ channels: - conda-forge - nodefaults dependencies: -- pymc>=5.20 +- pymc>=5.21 - pytest-cov>=2.5 - pytest>=3.0 - dask diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 1d1eb7745..d2a3e8934 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,7 +10,7 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 -- pymc>=5.20 +- pymc>=5.21 - pip: - blackjax - scikit-learn diff --git a/requirements.txt b/requirements.txt index a4f00ee21..3a1f85ac8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -pymc>=5.20 +pymc>=5.21 scikit-learn better-optimize From 34c66addff9a1ab638aa3ee2c1542aef86b776ac Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 1 Mar 2025 22:37:55 +0800 Subject: [PATCH 4/4] Ignore numpy warning from pymc --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 17187a524..e16b3b389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ filterwarnings =[ # Warning coming from blackjax 'ignore:jax\.tree_map is deprecated:DeprecationWarning', + + # Ignore PyMC use of numpy.core + 'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning', ] [tool.coverage.report]