We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9822ce5 commit 5e7e389Copy full SHA for 5e7e389
pymc/sampling_jax.py
@@ -27,6 +27,7 @@
27
from aesara.link.jax.dispatch import jax_funcify
28
from aesara.raise_op import Assert
29
from aesara.tensor import TensorVariable
30
+from aesara.tensor.shape import SpecifyShape
31
from arviz.data.base import make_attrs
32
33
from pymc import Model, modelcontext
@@ -38,6 +39,7 @@
38
39
40
@jax_funcify.register(Assert)
41
@jax_funcify.register(CheckParameterValue)
42
+@jax_funcify.register(SpecifyShape)
43
def jax_funcify_Assert(op, **kwargs):
44
# Jax does not allow assert whose values aren't known during JIT compilation
45
# within it's JIT-ed code. Hence we need to make a simple pass through
0 commit comments