Skip to content

Commit 5e7e389

Browse files
martiningramMartin Ingram
and
Martin Ingram
authored
Add SpecifyShape decorator to jax_funcify_Assert (#6062)
* Ignore SpecifyShape nodes in JAX mode Co-authored-by: Martin Ingram <[email protected]>
1 parent 9822ce5 commit 5e7e389

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pymc/sampling_jax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from aesara.link.jax.dispatch import jax_funcify
2828
from aesara.raise_op import Assert
2929
from aesara.tensor import TensorVariable
30+
from aesara.tensor.shape import SpecifyShape
3031
from arviz.data.base import make_attrs
3132

3233
from pymc import Model, modelcontext
@@ -38,6 +39,7 @@
3839

3940
@jax_funcify.register(Assert)
4041
@jax_funcify.register(CheckParameterValue)
42+
@jax_funcify.register(SpecifyShape)
4143
def jax_funcify_Assert(op, **kwargs):
4244
# Jax does not allow assert whose values aren't known during JIT compilation
4345
# within it's JIT-ed code. Hence we need to make a simple pass through

0 commit comments

Comments
 (0)