diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index dacec05e7..e3ea81f8f 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -11,6 +11,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.8.1 # CI was failing to resolve + - pymc>=5.9.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index c7cef7dc2..881ba51c7 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,5 +10,5 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.8.1 # CI was failing to resolve + - pymc>=5.9.0 # CI was failing to resolve - scikit-learn diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 6b9835c04..565b59e3d 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -31,7 +31,10 @@ def fit(method, **kwargs): """ if method == "pathfinder": try: - from pymc_experimental.inference.pathfinder import fit_pathfinder + import blackjax except ImportError as exc: raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc + + from pymc_experimental.inference.pathfinder import fit_pathfinder + return fit_pathfinder(**kwargs) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 94a0bb32f..2063e9509 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import collections import sys from typing import Optional @@ -20,31 +19,30 @@ import arviz as az import blackjax import jax -import jax.numpy as jnp -import jax.random as random import numpy as np import pymc as pm -from pymc import modelcontext +from packaging import version +from pymc.backends.arviz import coords_and_dims_for_inferencedata +from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.model import modelcontext from pymc.sampling.jax import get_jaxified_graph from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames def convert_flat_trace_to_idata( samples, - dims=None, - coords=None, include_transformed=False, postprocessing_backend="cpu", model=None, ): model = modelcontext(model) - init_position_dict = model.initial_point() + ip = model.initial_point() + ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info trace = collections.defaultdict(list) - astart = pm.blocking.DictToArrayBijection.map(init_position_dict) for sample in samples: - raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info) - point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict) + raveld_vars = RaveledVars(sample, ip_point_map_info) + point = DictToArrayBijection.rmap(raveld_vars, ip) for p, v in point.items(): trace[p].append(v.tolist()) @@ -57,19 +55,19 @@ def convert_flat_trace_to_idata( result = jax.vmap(jax.vmap(jax_fn))( *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0]) ) - trace = {v.name: r for v, r in zip(vars_to_sample, result)} + coords, dims = coords_and_dims_for_inferencedata(model) idata = az.from_dict(trace, dims=dims, coords=coords) return idata def fit_pathfinder( - iterations=5_000, + samples=1000, random_seed: Optional[RandomSeed] = None, postprocessing_backend="cpu", - ftol=1e-4, model=None, + **pathfinder_kwargs, ): """ Fit the pathfinder algorithm as implemented in blackjax @@ -78,15 +76,15 @@ def fit_pathfinder( Parameters ---------- - iterations : int - Number of iterations to run. + samples : int + Number of samples to draw from the fitted approximation. random_seed : int Random seed to set. postprocessing_backend : str Where to compute transformations of the trace. "cpu" or "gpu". - ftol : float - Floating point tolerance + pathfinder_kwargs: + kwargs for blackjax.vi.pathfinder.approximate Returns ------- @@ -96,17 +94,17 @@ def fit_pathfinder( --------- https://arxiv.org/abs/2108.03782 """ - - (random_seed,) = _get_seeds_per_chain(random_seed, 1) + # Temporarily helper + if version.parse(blackjax.__version__).major < 1: + raise ImportError("fit_pathfinder requires blackjax 1.0 or above") model = modelcontext(model) - rvs = [rv.name for rv in model.value_vars] - init_position_dict = model.initial_point() - init_position = [init_position_dict[rv] for rv in rvs] + ip = model.initial_point() + ip_map = DictToArrayBijection.map(ip) new_logprob, new_input = pm.pytensorf.join_nonshared_inputs( - init_position_dict, (model.logp(),), model.value_vars, () + ip, (model.logp(),), model.value_vars, () ) logprob_fn_list = get_jaxified_graph([new_input], new_logprob) @@ -114,35 +112,24 @@ def fit_pathfinder( def logprob_fn(x): return logprob_fn_list(x)[0] - dim = sum(v.size for v in init_position_dict.values()) - - rng_key = random.PRNGKey(random_seed) - w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) - path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) - - pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol) - state = pathfinder.init(w0) - - def inference_loop(rng_key, kernel, initial_state, num_samples): - @jax.jit - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) + [pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2) - keys = jax.random.split(rng_key, num_samples) - return jax.lax.scan(one_step, initial_state, keys) - - _, rng_key = random.split(rng_key) print("Running pathfinder...", file=sys.stdout) - _, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations) - - dims = { - var_name: [dim for dim in dims if dim is not None] - for var_name, dims in model.named_vars_to_dims.items() - } + pathfinder_state, _ = blackjax.vi.pathfinder.approximate( + rng_key=jax.random.key(pathfinder_seed), + logdensity_fn=logprob_fn, + initial_position=ip_map.data, + **pathfinder_kwargs, + ) + samples, _ = blackjax.vi.pathfinder.sample( + rng_key=jax.random.key(sample_seed), + state=pathfinder_state, + num_samples=samples, + ) idata = convert_flat_trace_to_idata( - samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims + samples, + postprocessing_backend=postprocessing_backend, + model=model, ) - return idata diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 1be2a22c9..7dacdb841 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -21,12 +21,7 @@ import pymc_experimental as pmx -# TODO: Remove this filterwarning after pytensor uses jnp.prod instead of jnp.product @pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.") -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="pymc.sampling.jax does not currently support python < 3.10" -) -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_pathfinder(): # Data of the Eight Schools Model J = 8 @@ -41,12 +36,11 @@ def test_pathfinder(): theta = pm.Normal("theta", mu=0, sigma=1, shape=J) obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y) - idata = pmx.fit(method="pathfinder", iterations=100) + idata = pmx.fit(method="pathfinder", random_seed=41) - assert idata is not None - assert "theta" in idata.posterior._variables.keys() - assert "tau" in idata.posterior._variables.keys() - assert "mu" in idata.posterior._variables.keys() - assert idata.posterior["mu"].shape == (1, 100) - assert idata.posterior["tau"].shape == (1, 100) - assert idata.posterior["theta"].shape == (1, 100, 8) + assert idata.posterior["mu"].shape == (1, 1000) + assert idata.posterior["tau"].shape == (1, 1000) + assert idata.posterior["theta"].shape == (1, 1000, 8) + # FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose model pathfinder can handle + # np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0) + np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5) diff --git a/requirements.txt b/requirements.txt index e9de9d758..2609234ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.8.1 +pymc>=5.8.2 scikit-learn