
ENH: JAX OP MatrixIsPositiveDefinite #6849

Closed
juanitorduz opened this issue Aug 7, 2023 · 6 comments · Fixed by #6853

Comments

@juanitorduz
Contributor

juanitorduz commented Aug 7, 2023

Hi! I tried to replicate the "Correlated priors with LKJCorr" model from https://tomicapretto.github.io/posts/2022-06-12_lkj-prior/#model-4-correlated-priors-with-lkjcorr.-replicate-rstanarm-prior using the numpyro sampler and got the following error:

Before

``NotImplementedError: No JAX conversion for the given `Op`: MatrixIsPositiveDefinite``

So it would be great to add this operator.

Is this something a newbie like me could try?

After

We should be able to run the model above with the JAX sampler from `numpyro`.
@ricardoV94
Member

ricardoV94 commented Aug 7, 2023

That seems to be an Op implemented in PyMC, not in PyTensor, so any dispatch should be implemented there.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.type import TensorType
from scipy import linalg

import pymc as pm


def posdef(AA):
    try:
        linalg.cholesky(AA)
        return 1
    except linalg.LinAlgError:
        return 0


class PosDefMatrix(Op):
    """
    Check if input is positive definite. Input should be a square matrix.
    """

    # Properties attribute
    __props__ = ()

    # Compulsory if itypes and otypes are not defined
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        assert x.ndim == 2
        o = TensorType(dtype="int8", shape=[])()
        return Apply(self, [x], [o])

    # Python implementation:
    def perform(self, node, inputs, outputs):
        (x,) = inputs
        (z,) = outputs
        try:
            z[0] = np.array(posdef(x), dtype="int8")
        except Exception:
            pm._log.exception("Failed to check if %s positive definite", x)
            raise

    def infer_shape(self, fgraph, node, shapes):
        return [[]]

    def grad(self, inp, grads):
        (x,) = inp
        return [x.zeros_like(pytensor.config.floatX)]

    def __str__(self):
        return "MatrixIsPositiveDefinite"


matrix_pos_def = PosDefMatrix()
```

That code relies on SciPy raising a `LinAlgError` when `cholesky` is called on a matrix that is not positive definite. A jitted `cholesky` in JAX doesn't raise; it returns NaNs instead. I don't know if this is guaranteed or if I was just lucky with the values I tested. If it is guaranteed, a reasonable dispatch might be `any(isnan(linalg.cholesky(x)))` (using JAX ops), negated so that it returns 1 for a positive-definite input, as the Op does.

Otherwise, ignoring the check, as we do for `Assert`, may be the best option: https://github.com/pymc-devs/pytensor/blob/4235ccc3f4243c5179178a206c15d84c4cda2e79/pytensor/link/jax/dispatch/basic.py#L73-L84
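The first option (a NaN check on the Cholesky factor) can be sketched as follows, assuming NumPy, SciPy and JAX are installed. SciPy's `linalg.cholesky` raises `LinAlgError` for a matrix that is not positive definite, whereas `jnp.linalg.cholesky` is assumed to return NaNs instead; as noted above, whether that JAX behavior is guaranteed is not confirmed. The helper names `posdef_scipy` and `is_pos_def_jax` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import linalg


def posdef_scipy(a):
    """Return 1 if SciPy's Cholesky succeeds, 0 if it raises LinAlgError."""
    try:
        linalg.cholesky(a)
        return 1
    except linalg.LinAlgError:
        return 0


@jax.jit
def is_pos_def_jax(a):
    """Hypothetical JAX counterpart. Assumes cholesky returns NaNs
    (rather than raising) when `a` is not positive definite."""
    has_nan = jnp.any(jnp.isnan(jnp.linalg.cholesky(a)))
    # Negate so the convention matches the Op: 1 = positive definite.
    return jnp.logical_not(has_nan).astype(jnp.int8)


good = np.array([[2.0, 0.5], [0.5, 1.0]])  # positive definite
bad = np.array([[1.0, 2.0], [2.0, 1.0]])   # eigenvalues 3 and -1

print(posdef_scipy(good), posdef_scipy(bad))  # 1 0
print(int(is_pos_def_jax(good)), int(is_pos_def_jax(bad)))
```

Note that the NaN check is inverted: `any(isnan(...))` is true exactly when the matrix is *not* positive definite.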

ricardoV94 transferred this issue from pymc-devs/pytensor Aug 7, 2023
ricardoV94 added the jax label Aug 7, 2023
@juanitorduz
Contributor Author

OK! So if I understand correctly, we could either use

```python
import jax.numpy as jnp

jnp.any(jnp.isnan(jnp.linalg.cholesky(x)))
```

or

> Otherwise ignoring like we do for Assert may be the best option: https://github.com/pymc-devs/pytensor/blob/4235ccc3f4243c5179178a206c15d84c4cda2e79/pytensor/link/jax/dispatch/basic.py#L73-L84

Can you clarify the second option (Otherwise ignoring like we do for Assert)? I do not get it 🙈

@ricardoV94
Member

ricardoV94 commented Aug 8, 2023

> Can you clarify the second option (Otherwise ignoring like we do for Assert)? I do not get it 🙈

Some Ops, like `Assert`, are not really supported by JAX, so our dispatch is basically a pass-through `lambda x: x` that ignores the checks.
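To illustrate why a pass-through is used, here is a minimal sketch assuming only JAX is installed (`checked_identity` and `passthrough` are hypothetical names). Under `jax.jit` the inputs are abstract tracers, so a Python-level check that needs a concrete boolean fails while tracing; the pass-through version simply returns its input and drops the check.

```python
import jax
import jax.numpy as jnp


def checked_identity(x):
    # An Assert-like check: branching in Python needs a concrete value.
    if not jnp.all(x > 0):
        raise ValueError("check failed")
    return x


def passthrough(x, *conditions):
    # Mirrors the `lambda x: x` style dispatch: conditions are ignored.
    return x


x = jnp.array([1.0, 2.0])

try:
    jax.jit(checked_identity)(x)
    tracing_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Raised during tracing: `x > 0` is a tracer, not a Python bool.
    tracing_failed = True
    print(type(err).__name__)

print(jax.jit(passthrough)(x, x > 0))  # [1. 2.]
```

The trade-off is that the compiled function silently skips the check, which is why option 1 is preferred when a traceable check exists.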

I would try option 1 first.

@juanitorduz
Contributor Author

I want to try this one out! Just to get started: where should I add the following JAX code?

```python
import jax.numpy as jnp

jnp.any(jnp.isnan(jnp.linalg.cholesky(x)))
```

@ricardoV94
Member

ricardoV94 commented Aug 11, 2023

Check this guide (for PyTensor): https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
The last step (tests) is PyTensor-specific. We will test it more directly in PyMC because we don't have access to those test helpers from here.

Then the code would probably go here: https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py

And the tests would go in the analogous test file.

@ricardoV94
Member

ricardoV94 commented Aug 11, 2023

While we are at it, we can remove the `SpecifyShape` registration from here:

```python
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.raise_op import Assert
from pytensor.tensor.shape import SpecifyShape

from pymc.logprob.utils import CheckParameterValue


@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
@jax_funcify.register(SpecifyShape)
def jax_funcify_Assert(op, **kwargs):
    # Jax does not allow assert whose values aren't known during JIT compilation
    # within its JIT-ed code. Hence we need to make a simple pass through
    # version of the Assert Op.
    # https://github.com/google/jax/issues/2273#issuecomment-589098722
    def assert_fn(value, *inps):
        return value

    return assert_fn
```

That's because `SpecifyShape` is actually supported, and well implemented, in PyTensor's JAX backend.
