diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 99fcf65a..b856eb41 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -463,7 +463,7 @@ def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # Notes ----- The `instances` iterable must yield at least the same number of elements as the ones - returned by ``pickle_without``, but the elements do not need to be the same objects + returned by ``pickle_flatten``, but the elements do not need to be the same objects or even the same types of objects. Excess elements, if any, will be left untouched. """ iters = iter(instances), iter(rest) @@ -540,6 +540,25 @@ def jax_autojit( See Also -------- jax.jit : JAX JIT compilation function. + + Notes + ----- + These are useful choices *for testing purposes only*, which is how this function is + intended to be used. The output of ``jax.jit`` is a C++ level callable, that + directly dispatches to the compiled kernel after the initial call. In comparison, + ``jax_autojit`` incurs a much higher dispatch time. + + Additionally, consider:: + + def f(x: Array, y: float, plus: bool) -> Array: + return x + y if plus else x - y + + j1 = jax.jit(f, static_argnames="plus") + j2 = jax_autojit(f) + + In the above example, ``j2`` requires a lot less setup to be tested effectively than + ``j1``, but on the flip side it means that it will be re-traced for every different + value of ``y``, which likely makes it not fit for purpose in production. """ import jax diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index c14e9a22..3979f9dd 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any] jax_jit : bool, optional Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. + This is the default behaviour. Set to False if `func` is only compatible with eager (non-jitted) JAX. Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX