Closed
Description
Running pymc_experimental 0.0.12 and blackjax 1.0.0.
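A quick way to collect the installed versions for a report like this is `importlib.metadata` from the standard library. This is a minimal sketch: the helper name `installed_versions` is hypothetical, and the distribution names passed to it (`pymc-experimental`, `blackjax`, ...) are assumed to match how the packages were installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Distribution names are assumptions; adjust to match your install.
    for name, ver in installed_versions(["pymc", "pymc-experimental", "blackjax", "jax"]).items():
        print(f"{name}: {ver or 'not installed'}")
```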
MRE:
```python
import pymc as pm
import pymc_experimental as pmx
import numpy as np

with pm.Model():
    x = pm.Normal('x')
    obs = pm.Normal('obs', mu=x, sigma=1, observed=np.random.normal(loc=3, size=(100,)))
    idata = pmx.fit('pathfinder')
```
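This MRE has a closed-form posterior, which gives a reference to check against once `pmx.fit('pathfinder')` runs again. With a N(0, 1) prior on `x` and 100 observations with unit noise, the posterior for `x` is normal with precision 1 + 100 = 101 and mean sum(y) / 101. The sketch below uses only the standard library; the helper name `normal_normal_posterior` is hypothetical, and `random.gauss` with a fixed seed stands in for the unseeded `np.random.normal` draw, so the exact numbers will differ from any particular run of the MRE.

```python
import math
import random


def normal_normal_posterior(y, prior_mu=0.0, prior_sigma=1.0, sigma=1.0):
    """Exact posterior (mean, sd) of x for x ~ N(prior_mu, prior_sigma), y_i ~ N(x, sigma)."""
    prior_prec = 1.0 / prior_sigma**2
    like_prec = len(y) / sigma**2
    post_prec = prior_prec + like_prec
    # Precision-weighted combination of the prior mean and the data.
    post_mean = (prior_prec * prior_mu + sum(y) / sigma**2) / post_prec
    return post_mean, math.sqrt(1.0 / post_prec)


if __name__ == "__main__":
    rng = random.Random(0)
    y = [rng.gauss(3.0, 1.0) for _ in range(100)]
    mean, sd = normal_normal_posterior(y)
    print(f"posterior mean {mean:.3f}, sd {sd:.3f}")
```

Pathfinder returns an approximation, so its draws for `x` should land close to, not exactly on, these values.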
Raises:
```
AttributeError                            Traceback (most recent call last)
Cell In[15], line 4
      2     x = pm.Normal('x')
      3     obs = pm.Normal('obs', mu=x, sigma=1, observed=np.random.normal(loc=3, size=(100,)))
----> 4     idata = pmx.fit('pathfinder')

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc_experimental/inference/fit.py:37, in fit(method, **kwargs)
     35     except ImportError as exc:
     36         raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc
---> 37     return fit_pathfinder(**kwargs)

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc_experimental/inference/pathfinder.py:121, in fit_pathfinder(iterations, random_seed, postprocessing_backend, ftol, model)
    119     rng_key = random.PRNGKey(random_seed)
    120     w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim))
--> 121     path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)
    123     pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol)
    124     state = pathfinder.init(w0)

AttributeError: module 'blackjax.vi.pathfinder' has no attribute 'init'
```
@junpenglao Did the blackjax pathfinder API change? I tried to check the git history over there but didn't see anything obvious.
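The traceback shows that `blackjax.vi.pathfinder` in blackjax 1.0.0 no longer has `init`. If 1.0 replaced that entry point with `approximate(rng_key, logdensity_fn, initial_position, ...)` (an assumption to verify against the installed BlackJAX source), a version-tolerant call could look like the sketch below. `init_pathfinder` is a hypothetical helper, and `SimpleNamespace` stand-ins replace the real module so the dispatch logic can be exercised without BlackJAX installed.

```python
from types import SimpleNamespace


def init_pathfinder(pathfinder_module, rng_key, logprob_fn, w0, ftol):
    """Call whichever pathfinder entry point the installed BlackJAX exposes.

    ASSUMPTION: BlackJAX 1.0.0 exposes ``approximate`` in place of the
    ``init`` that pymc_experimental 0.0.12 calls; verify this against the
    installed version before relying on it.
    """
    if hasattr(pathfinder_module, "approximate"):
        # Assumed 1.0-style call: (rng_key, logdensity_fn, initial_position).
        return pathfinder_module.approximate(rng_key, logprob_fn, w0, ftol=ftol)
    if hasattr(pathfinder_module, "init"):
        # Pre-1.0 call, exactly as it appears in the traceback.
        return pathfinder_module.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)
    raise AttributeError("blackjax.vi.pathfinder exposes neither 'approximate' nor 'init'")


if __name__ == "__main__":
    # Stand-in modules that just echo which entry point was called.
    old_api = SimpleNamespace(init=lambda *args, **kwargs: ("init", kwargs))
    new_api = SimpleNamespace(approximate=lambda *args, **kwargs: ("approximate", kwargs))
    print(init_pathfinder(old_api, None, None, None, ftol=1e-5))
    print(init_pathfinder(new_api, None, None, None, ftol=1e-5))
```

The two branches are not drop-in equivalents: `return_path` is only passed on the old branch, and what each entry point returns may differ, so the downstream code (including the `blackjax.kernels.pathfinder` call at line 123 of the traceback) would likely need updating as well.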