Skip to content

Adding NUTS sampler from blackjax to sampling_jax #5477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 11, 2022

Conversation

zaxtax
Copy link
Contributor

@zaxtax zaxtax commented Feb 16, 2022

This PR adds blackjax support into sampling_jax

It's not perfect (progress bar support and sampling stats are not fully supported in blackjax). But it works with all transformed and deterministic variables.

@zaxtax
Copy link
Contributor Author

zaxtax commented Feb 16, 2022

Although not mentioned @junpenglao and @rlouf helped a lot with getting this to work from the blackjax side.

@codecov
Copy link

codecov bot commented Feb 16, 2022

Codecov Report

Merging #5477 (d5fcbb0) into main (44c5495) will increase coverage by 0.05%.
The diff coverage is 94.73%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5477      +/-   ##
==========================================
+ Coverage   87.59%   87.64%   +0.05%     
==========================================
  Files          76       76              
  Lines       13694    13765      +71     
==========================================
+ Hits        11995    12065      +70     
- Misses       1699     1700       +1     
Impacted Files Coverage Δ
pymc/sampling_jax.py 96.96% <94.73%> (-1.46%) ⬇️
pymc/parallel_sampling.py 87.70% <0.00%> (+0.99%) ⬆️

@rlouf
Copy link
Contributor

rlouf commented Feb 16, 2022

Great! We can add a progress with callbacks, as is done in MCX. Afair the overhead it introduces is very small.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. I'll have another look later, but left a comment above

@twiecki
Copy link
Member

twiecki commented Feb 16, 2022

This is awesome, thanks @zaxtax!
Closes #5454

@ricardoV94 ricardoV94 added the jax label Feb 16, 2022
model_logpt = model.logpt()
if not negative_logp:
model_logpt = -model_logpt
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])

def logp_fn_wrap(x):
# NumPyro expects a scalar potential with the opposite sign of model.logpt
res = logp_fn(*x)[0]
Copy link
Member

@ricardoV94 ricardoV94 Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's return this directly, instead of saving it intermediately in res

return inference_loop(seed, last_state)


def sample_blackjax_nuts(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function needs to be documented and added to the API docs (where the numpyro one might also be missing?). It probably makes more sense to add in the first block of https://github.com/pymc-devs/pymc/blob/main/docs/source/api/samplers.rst.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call.

Copy link
Contributor Author

@zaxtax zaxtax Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit rusty with sphinx, where do I specify that the functions are in sampling_jax

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by that point currentmodule is set to pymc, so sphinx assumes functions live there. If they are in pymc.sampling_jax.<function> you'll need to use sampling_jax.<function> as the pymc is assumed because of the currentmodule use

observed_data=find_observations(model),
coords=coords,
dims=dims,
attrs={"sampling_time": (tic3 - tic2).total_seconds()},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also add pymc and blackjax versions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh is there an arviz convention for when there are multiple libraries? This came up in earlier discussions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that I'm aware of 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to put in blackjax and numpyro as the inference library, and add pymc_version as another field.

@@ -30,6 +30,8 @@ HMC family

NUTS
HamiltonianMC
sampling_jax.sample_blackjax_nuts
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they should be in the first block (or in a different subsection), IIUC, both are nuts only even if blackjax has other samplers, but they are not step methods that can be used as inputs to pm.sample

@@ -188,6 +188,45 @@ def sample_blackjax_nuts(
chain_method="parallel",
idata_kwargs=None,
):
"""Draw samples from the posterior using the NUTS method from the blackjax library.

Parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update with the advise on https://pymc-data-umbrella.xyz/en/latest/sprint/docstring_tutorial.html#edit-the-docstring and I'll take a 2nd look later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must be missing something but I think I'm style compliant. I actually copied the docstrings from sample as my starting point.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically none of the docstrings are currently style compliant, and most aren't rendered correctly either because of that. We have an issue open for this: #5459. I can give a quick go at one of the docstrings. I think you'll also need to rebase on main for the docs preview to work and CI to pass.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some comments, but they are not exaustive

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok it was how I'm specifying defaults and optional.

Comment on lines 197 to 199
draws : int
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded
by default.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
draws : int
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded
by default.
draws : int, default 1000
The number of samples to draw. The number of tuned samples are discarded
by default.

Comment on lines 211 to 212
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
context, it defaults to that model, otherwise the model must be passed explicitly.

Comment on lines 213 to 214
var_names : Iterable[str]
Names of variables for which to compute the posterior samples.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var_names : Iterable[str]
Names of variables for which to compute the posterior samples.
var_names : iterable of str, optional
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior

Model to sample from. The model needs to have free random variables.
var_names : Iterable[str]
Names of variables for which to compute the posterior samples.
progress_bar : bool, optional default=True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
progress_bar : bool, optional default=True
progress_bar : bool, default True

Specify how samples should be drawn. The choices include "parallel", and
"vectorized". Defaults to "parallel".
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Keyword arguments for :func:`pymc.to_inference_data`
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value for the ``log_likelihood``
key to indicate that the pointwise log likelihood should not be included in the returned object.

Comment on lines 228 to 229
trace : arviz.InferenceData
ArviZ ``InferenceData`` object that contains the samples.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
trace : arviz.InferenceData
ArviZ ``InferenceData`` object that contains the samples.
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
pointwise log likeihood values (unless skipped with ``idata_kwargs``).

Comment on lines 200 to 201
tune : int
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tune : int
Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or
tune : int, default 1000
Number of iterations to tune. Samplers adjust the step sizes, scalings or

@zaxtax zaxtax force-pushed the adding-blackjax-support branch from e1a636f to 8781b32 Compare February 16, 2022 20:35
@zaxtax zaxtax force-pushed the adding-blackjax-support branch from 8781b32 to 8a27c1f Compare February 16, 2022 20:38
@OriolAbril
Copy link
Member

The functions are still missing from docs (https://pymc--5477.org.readthedocs.build/en/5477/api/samplers.html) due to this issue:

WARNING: [autosummary] failed to import pymc.sampling_jax.sample_blackjax_nuts.
Possible hints:
* AttributeError: module 'pymc' has no attribute 'sampling_jax'
* ImportError: no module named pymc.sampling_jax
* ModuleNotFoundError: No module named 'jax'
WARNING: [autosummary] failed to import pymc.sampling_jax.sample_numpyro_nuts.
Possible hints:
* AttributeError: module 'pymc' has no attribute 'sampling_jax'
* ImportError: no module named pymc.sampling_jax
* ModuleNotFoundError: No module named 'jax'

what do we think is a better option, should jax be imported in a try except or installed in the docs env?

@zaxtax
Copy link
Contributor Author

zaxtax commented Feb 16, 2022

The functions are still missing from docs (https://pymc--5477.org.readthedocs.build/en/5477/api/samplers.html) due to this issue:

WARNING: [autosummary] failed to import pymc.sampling_jax.sample_blackjax_nuts.
Possible hints:
* AttributeError: module 'pymc' has no attribute 'sampling_jax'
* ImportError: no module named pymc.sampling_jax
* ModuleNotFoundError: No module named 'jax'
WARNING: [autosummary] failed to import pymc.sampling_jax.sample_numpyro_nuts.
Possible hints:
* AttributeError: module 'pymc' has no attribute 'sampling_jax'
* ImportError: no module named pymc.sampling_jax
* ModuleNotFoundError: No module named 'jax'

what do we think is a better option, should jax be imported in a try except or installed in the docs env?

I'm already installing jax in the other test environment. But I don't maybe there should be a try-except in the module itself. Would the latter suppress this error?

@OriolAbril
Copy link
Member

OriolAbril commented Feb 16, 2022

yep, both should fix the issue. Sphinx imports the objects to be documented in order to extract the __doc__ attributes, but it doesn't use them. The problem here is that it errors out before being able to import the two functions due to the jax import. I don't know the best way to go about it because I don't really understand how dependencies+dev+test+optional dependencies are organized or how problematic it would be to modify this import (I haven't used jax samplers yet)

@zaxtax
Copy link
Contributor Author

zaxtax commented Feb 16, 2022

yep, both should fix the issue. Sphinx imports the objects to be documented in order to extract the __doc__ attributes, but it doesn't use them. The problem here is that it errors out before being able to import the two functions due to the jax import. I don't know the best way to go about it because I don't really understand how dependencies+dev+test+optional dependencies are organized or how problematic it would be to modify this import (I haven't used jax samplers yet)

As the name implies, jax is required to really use anything in this module. But I think we can suppress some of this stuff.

@twiecki
Copy link
Member

twiecki commented Feb 19, 2022 via email

@OriolAbril
Copy link
Member

Aren't the two almost identical except for compilers and fixing blas?

assuming this is in reference to test and docs environments. Both have the common step of installing pymc but that is the only thing they have in common. Test then needs to execute and test virtually all the functions so it needs all optional dependencies (even runtime ones), needs pytest, things for lint, pre-commit, coverage... Whereas docs only needs to execute 4 notebooks, for everything else importing the objects is enough, and it needs sphinx and all the extensions which explode in dependencies very quickly, myst-nb needs myst-parser and jupyter to execute and render notebooks for example.

@@ -457,7 +462,7 @@ def sample_numpyro_nuts(

if random_seed is None:
random_seed = model.rng_seeder.randint(
2**30, dtype=np.int64, size=chains if chains > 1 else None
2 ** 30, dtype=np.int64, size=chains if chains > 1 else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to update pre-commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do I do that?

@zaxtax
Copy link
Contributor Author

zaxtax commented Feb 28, 2022

I don't know what's going on with the tests failing now, but this runs locally for me.

I've rebased from main, but would like this go in sooner since it's a larger contribution.

Postponing adding progress bar support until callback support has been included within blackjax

@ericmjl
Copy link
Member

ericmjl commented Feb 28, 2022

@zaxtax curious to hear, in reading the logs, I see a "core dumped" error showing up. Could the issue be RAM usage?

The GH Actions runners have ~7GB of RAM, I think, I read here.

If I'm wrong, please feel free to correct me!

@twiecki twiecki requested a review from ricardoV94 February 28, 2022 16:20
@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 3, 2022

@ricardoV94 any idea why testing is timing out?

@ricardoV94
Copy link
Member

@ricardoV94 any idea why testing is timing out?

Nope

@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 3, 2022

@ricardoV94 any idea why testing is timing out?

Nope

Then I'm a bit out of ideas how to push this PR forward.

@ricardoV94 ricardoV94 added this to the v4.0.0b4 milestone Mar 7, 2022
@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 10, 2022

@ricardoV94 looks like the tests pass now!

@ricardoV94
Copy link
Member

@ricardoV94 looks like the tests pass now!

Great! I manually made the docs rebuild to see if the failure was just flaky

@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 10, 2022

@ricardoV94 looks like the tests pass now!

Great! I manually made the docs rebuild to see if the failure was just flaky

Looks like it does!

@ericmjl
Copy link
Member

ericmjl commented Mar 10, 2022

Looks like we’re good to go!

@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 11, 2022

@ricardoV94 is this ready to merge?

@twiecki twiecki merged commit b799547 into pymc-devs:main Mar 11, 2022
@twiecki
Copy link
Member

twiecki commented Mar 11, 2022

Thanks @zaxtax, awesome to have this in!

@twiecki
Copy link
Member

twiecki commented Mar 11, 2022

One thing we forgot are release notes.

@zaxtax
Copy link
Contributor Author

zaxtax commented Mar 11, 2022 via email

@ricardoV94
Copy link
Member

ricardoV94 commented Mar 12, 2022

I haven't done that for a lot of my sampling_jax contributions

On Fri, 11 Mar 2022, 21:20 Thomas Wiecki, @.> wrote: One thing we forgot are release notes. — Reply to this email directly, view it on GitHub <#5477 (comment)>, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAACCULYOC4ZCZ7XE62YSSTU7OTKJANCNFSM5OQKUNPQ . Triage notifications on the go with GitHub Mobile for iOS https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675 or Android https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub. You are receiving this because you were mentioned.Message ID: @.>

Yeah we were in this regime where sampling_jax is still "experimental" but now with it showing on the docs and all the work we have been doing is getting more and more official. So we should try to remember and mention changes in release notes, specially these large ones!

We definitely need one for this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants