Skip to content

Warn on divergences after sampling with JAX #7051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Dec 6, 2023

What is this PR about?
Closes #7041

This is a draft to add convergence checks to the JAX samplers.

Right now I'm just calling run_convergence_checks after sampling. It might be nice to instead wrap the JAX returns in _sample_return, but this was easier. Feedback requested.

Checklist

Major / Breaking Changes

  • The blackjax sampler should now issue a warning on divergences

New features

  • None

Bugfixes

  • None

Documentation

  • None

Maintenance

  • None

📚 Documentation preview 📚: https://pymc--7051.org.readthedocs.build/en/7051/

Sorry, something went wrong.

@jessegrabowski jessegrabowski changed the title call convergence_check after sampling with numpyro Warn on divergences after sampling with JAX Dec 6, 2023
Copy link

codecov bot commented Dec 6, 2023

Codecov Report

Merging #7051 (e28c71e) into main (005ba5f) will increase coverage by 0.00%.
Report is 1 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #7051   +/-   ##
=======================================
  Coverage   92.16%   92.16%           
=======================================
  Files         101      101           
  Lines       16827    16831    +4     
=======================================
+ Hits        15509    15513    +4     
  Misses       1318     1318           
Files Coverage Δ
pymc/sampling/mcmc.py 87.79% <100.00%> (+0.10%) ⬆️

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 7, 2023

Looks fine, but why not in the sample_numpyro_nuts function itself? I am thinking of that because some users use that directly and also we already re-implement most of the logic of _sample_return there?

@jessegrabowski
Copy link
Member Author

I whipped up this as a live example during the sprint. I'll go back and do a better job following your suggestion @ricardoV94

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Report divergences from JAX samplers
3 participants