-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Adding NUTS sampler from blackjax to sampling_jax #5477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Although not mentioned @junpenglao and @rlouf helped a lot with getting this to work from the blackjax side. |
Codecov Report
@@ Coverage Diff @@
## main #5477 +/- ##
==========================================
+ Coverage 87.59% 87.64% +0.05%
==========================================
Files 76 76
Lines 13694 13765 +71
==========================================
+ Hits 11995 12065 +70
- Misses 1699 1700 +1
|
Great! We can add a progress with callbacks, as is done in MCX. Afair the overhead it introduces is very small. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. I'll have another look later, but left a comment above
pymc/sampling_jax.py
Outdated
model_logpt = model.logpt() | ||
if not negative_logp: | ||
model_logpt = -model_logpt | ||
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt]) | ||
|
||
def logp_fn_wrap(x): | ||
# NumPyro expects a scalar potential with the opposite sign of model.logpt | ||
res = logp_fn(*x)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's return this directly, instead of saving it intermediately in res
return inference_loop(seed, last_state) | ||
|
||
|
||
def sample_blackjax_nuts( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function needs to be documented and added to the API docs (where the numpyro one might also be missing?). It probably makes more sense to add in the first block of https://github.com/pymc-devs/pymc/blob/main/docs/source/api/samplers.rst.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit rusty with sphinx, where do I specify that the functions are in sampling_jax
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by that point currentmodule
is set to pymc, so sphinx assumes functions live there. If they are in pymc.sampling_jax.<function>
you'll need to use sampling_jax.<function>
as the pymc is assumed because of the currentmodule use
pymc/sampling_jax.py
Outdated
observed_data=find_observations(model), | ||
coords=coords, | ||
dims=dims, | ||
attrs={"sampling_time": (tic3 - tic2).total_seconds()}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we also add pymc and blackjax versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure thing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh is there an arviz convention for when there are multiple libraries? This came up in earlier discussions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not that I'm aware of 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to put in blackjax and numpyro as the inference library, and add pymc_version as another field.
docs/source/api/samplers.rst
Outdated
@@ -30,6 +30,8 @@ HMC family | |||
|
|||
NUTS | |||
HamiltonianMC | |||
sampling_jax.sample_blackjax_nuts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think they should be in the first block (or in a different subsection), IIUC, both are nuts only even if blackjax has other samplers, but they are not step methods that can be used as inputs to pm.sample
@@ -188,6 +188,45 @@ def sample_blackjax_nuts( | |||
chain_method="parallel", | |||
idata_kwargs=None, | |||
): | |||
"""Draw samples from the posterior using the NUTS method from the blackjax library. | |||
|
|||
Parameters |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update with the advise on https://pymc-data-umbrella.xyz/en/latest/sprint/docstring_tutorial.html#edit-the-docstring and I'll take a 2nd look later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I must be missing something but I think I'm style compliant. I actually copied the docstrings from sample
as my starting point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically none of the docstrings are currently style compliant, and most aren't rendered correctly either because of that. We have an issue open for this: #5459. I can give a quick go at one of the docstrings. I think you'll also need to rebase on main for the docs preview to work and CI to pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added some comments, but they are not exaustive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh ok it was how I'm specifying defaults and optional.
pymc/sampling_jax.py
Outdated
draws : int | ||
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded | ||
by default. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
draws : int | |
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded | |
by default. | |
draws : int, default 1000 | |
The number of samples to draw. The number of tuned samples are discarded | |
by default. |
pymc/sampling_jax.py
Outdated
model : Model (optional if in ``with`` context) | ||
Model to sample from. The model needs to have free random variables. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model : Model (optional if in ``with`` context) | |
Model to sample from. The model needs to have free random variables. | |
model : Model, optional | |
Model to sample from. The model needs to have free random variables. When inside a ``with`` model | |
context, it defaults to that model, otherwise the model must be passed explicitly. |
pymc/sampling_jax.py
Outdated
var_names : Iterable[str] | ||
Names of variables for which to compute the posterior samples. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
var_names : Iterable[str] | |
Names of variables for which to compute the posterior samples. | |
var_names : iterable of str, optional | |
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior |
pymc/sampling_jax.py
Outdated
Model to sample from. The model needs to have free random variables. | ||
var_names : Iterable[str] | ||
Names of variables for which to compute the posterior samples. | ||
progress_bar : bool, optional default=True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
progress_bar : bool, optional default=True | |
progress_bar : bool, default True |
pymc/sampling_jax.py
Outdated
Specify how samples should be drawn. The choices include "parallel", and | ||
"vectorized". Defaults to "parallel". | ||
idata_kwargs : dict, optional | ||
Keyword arguments for :func:`pymc.to_inference_data` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keyword arguments for :func:`pymc.to_inference_data` | |
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood`` | |
key to indicate that the pointwise log likelihood should not be included in the returned object. |
pymc/sampling_jax.py
Outdated
trace : arviz.InferenceData | ||
ArviZ ``InferenceData`` object that contains the samples. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trace : arviz.InferenceData | |
ArviZ ``InferenceData`` object that contains the samples. | |
InferenceData | |
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and | |
pointwise log likeihood values (unless skipped with ``idata_kwargs``). |
pymc/sampling_jax.py
Outdated
tune : int | ||
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tune : int | |
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or | |
tune : int, default 1000 | |
Number of iterations to tune. Samplers adjust the step sizes, scalings or |
e1a636f
to
8781b32
Compare
8781b32
to
8a27c1f
Compare
The functions are still missing from docs (https://pymc--5477.org.readthedocs.build/en/5477/api/samplers.html) due to this issue:
what do we think is a better option, should jax be imported in a try except or installed in the docs env? |
I'm already installing |
yep, both should fix the issue. Sphinx imports the objects to be documented in order to extract the |
As the name implies, |
Aren't the two almost identical except for compilers and fixing blas?
…On Sat, Feb 19, 2022, 13:31 Ricardo Vieira ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In conda-envs/environment-dev-py37.yml
<#5477 (comment)>:
> @@ -13,6 +13,7 @@ dependencies:
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
+- jax
I think docs requires quite more dependencies than tests
—
Reply to this email directly, view it on GitHub
<#5477 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAFETGD4JPUWRWQSCIQPCWTU36EQPANCNFSM5OQKUNPQ>
.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
You are receiving this because you commented.Message ID:
***@***.***>
|
assuming this is in reference to test and docs environments. Both have the common step of installing pymc but that is the only thing they have in common. Test then needs to execute and test virtually all the functions so it needs all optional dependencies (even runtime ones), needs pytest, things for lint, pre-commit, coverage... Whereas docs only needs to execute 4 notebooks, for everything else importing the objects is enough, and it needs sphinx and all the extensions which explode in dependencies very quickly, myst-nb needs myst-parser and jupyter to execute and render notebooks for example. |
pymc/sampling_jax.py
Outdated
@@ -457,7 +462,7 @@ def sample_numpyro_nuts( | |||
|
|||
if random_seed is None: | |||
random_seed = model.rng_seeder.randint( | |||
2**30, dtype=np.int64, size=chains if chains > 1 else None | |||
2 ** 30, dtype=np.int64, size=chains if chains > 1 else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you need to update pre-commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do I do that?
I don't know what's going on with the tests failing now, but this runs locally for me. I've rebased from main, but would like this go in sooner since it's a larger contribution. Postponing adding progress bar support until callback support has been included within |
@ricardoV94 any idea why testing is timing out? |
Nope |
Then I'm a bit out of ideas how to push this PR forward. |
@ricardoV94 looks like the tests pass now! |
Great! I manually made the docs rebuild to see if the failure was just flaky |
Looks like it does! |
Looks like we’re good to go! |
@ricardoV94 is this ready to merge? |
Thanks @zaxtax, awesome to have this in! |
One thing we forgot are release notes. |
I haven't done that for a lot of my `sampling_jax` contributions
…On Fri, 11 Mar 2022, 21:20 Thomas Wiecki, ***@***.***> wrote:
One thing we forgot are release notes.
—
Reply to this email directly, view it on GitHub
<#5477 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCULYOC4ZCZ7XE62YSSTU7OTKJANCNFSM5OQKUNPQ>
.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Yeah we were in this regime where sampling_jax is still "experimental" but now with it showing on the docs and all the work we have been doing is getting more and more official. So we should try to remember and mention changes in release notes, specially these large ones! We definitely need one for this |
This PR adds blackjax support into
sampling_jax
It's not perfect (progress bar support and sampling stats are not fully supported in blackjax). But it works with all transformed and deterministic variables.