
ENH: cache helper functions #308

Merged
merged 7 commits on Apr 21, 2025
Changes from all commits
192 changes: 108 additions & 84 deletions array_api_compat/common/_helpers.py
@@ -12,7 +12,8 @@
 import math
 import sys
 import warnings
-from collections.abc import Collection
+from collections.abc import Collection, Hashable
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -61,23 +62,37 @@
 _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})


+@lru_cache(100)
+def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
+    try:
+        mod = sys.modules[modname]
+    except KeyError:
+        return False
+    parent_cls = getattr(mod, clsname)
+    return issubclass(cls, parent_cls)
+
+
 def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.

     These arrays are a design quirk of Jax that may one day be removed.
     See https://github.com/google/jax/issues/20620.
     """
-    if "numpy" not in sys.modules or "jax" not in sys.modules:
-        # Fast exit
-        return False
+    try:
+        dtype = x.dtype  # type: ignore[attr-defined]
+    except AttributeError:
+        return False
+    cls = cast(Hashable, type(dtype))
Member:

Can't say I'm happy to see cryptic things from typing doing something at runtime, but OK, am ready to believe it's somehow useful.

Contributor (Author):

cast is a no-op at runtime.
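Indeed, `typing.cast` only informs the static type checker; at runtime it returns its first argument unchanged. A minimal illustration:

import typing

x = [1, 2, 3]
y = typing.cast("list[str]", x)  # no conversion or validation takes place
assert y is x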

+    if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
+        return False

-    import jax
-    import numpy as np
+    if "jax" not in sys.modules:
+        return False

-    jax_float0 = cast("np.dtype[np.void]", jax.float0)
-    return (
-        isinstance(x, np.ndarray)
-        and cast("npt.NDArray[np.void]", x).dtype == jax_float0
-    )
+    import jax
+
+    # jax.float0 is a np.dtype([('float0', 'V')])
+    return dtype == jax.float0
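For context, a hedged sketch of where such zero-gradient arrays come from (assuming jax is installed): differentiating with respect to an integer argument yields one.

import jax
import jax.numpy as jnp

# The gradient w.r.t. the integer argument `n` (argnums=1) is a zero-gradient
# array: a NumPy array whose dtype equals jax.float0.
g = jax.grad(lambda x, n: jnp.sin(x), argnums=1, allow_int=True)(2.0, 3)
assert g.dtype == jax.float0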


def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
@@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
     is_jax_array
     is_pydata_sparse_array
     """
-    # Avoid importing NumPy if it isn't already
-    if "numpy" not in sys.modules:
-        return False
-
-    import numpy as np
-
-    # TODO: Should we reject ndarray subclasses?
-    return (isinstance(x, (np.ndarray, np.generic))
-            and not _is_jax_zero_gradient_array(x))  # pyright: ignore[reportUnknownArgumentType]  # fmt: skip
+    cls = cast(Hashable, type(x))
+    return (
+        _issubclass_fast(cls, "numpy", "ndarray")
+        or _issubclass_fast(cls, "numpy", "generic")
+    ) and not _is_jax_zero_gradient_array(x)
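To see the cache at work (a sketch, assuming numpy is installed): repeated checks with the same class are answered from `lru_cache` without re-probing `sys.modules`.

import numpy as np

for _ in range(3):
    is_numpy_array(np.ones(2))

# Expect one miss per distinct (cls, modname, clsname) key, then hits.
print(_issubclass_fast.cache_info())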


def is_cupy_array(x: object) -> bool:
@@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool:
     is_jax_array
     is_pydata_sparse_array
     """
-    # Avoid importing CuPy if it isn't already
-    if "cupy" not in sys.modules:
-        return False
-
-    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
-
-    # TODO: Should we reject ndarray subclasses?
-    return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "cupy", "ndarray")


def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
@@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
     is_jax_array
     is_pydata_sparse_array
     """
-    # Avoid importing torch if it isn't already
-    if "torch" not in sys.modules:
-        return False
-
-    import torch
-
-    # TODO: Should we reject ndarray subclasses?
-    return isinstance(x, torch.Tensor)
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "torch", "Tensor")


def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
@@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
     is_jax_array
     is_pydata_sparse_array
     """
-    # Avoid importing torch if it isn't already
-    if "ndonnx" not in sys.modules:
-        return False
-
-    import ndonnx as ndx
-
-    return isinstance(x, ndx.Array)
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "ndonnx", "Array")


def is_dask_array(x: object) -> TypeIs[da.Array]:
@@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
     is_jax_array
     is_pydata_sparse_array
     """
-    # Avoid importing dask if it isn't already
-    if "dask.array" not in sys.modules:
-        return False
-
-    import dask.array
-
-    return isinstance(x, dask.array.Array)
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "dask.array", "Array")


def is_jax_array(x: object) -> TypeIs[jax.Array]:
@@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
     is_dask_array
     is_pydata_sparse_array
     """
-    # Avoid importing jax if it isn't already
-    if "jax" not in sys.modules:
-        return False
-
-    import jax
-
-    return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     is_dask_array
     is_jax_array
     """
-    # Avoid importing jax if it isn't already
-    if "sparse" not in sys.modules:
-        return False
-
-    import sparse  # pyright: ignore[reportMissingTypeStubs]
-
-    # TODO: Account for other backends.
-    return isinstance(x, sparse.SparseArray)
+    cls = cast(Hashable, type(x))
+    return _issubclass_fast(cls, "sparse", "SparseArray")


def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
@@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:
     is_jax_array
     """
     return (
-        is_numpy_array(x)
-        or is_cupy_array(x)
-        or is_torch_array(x)
-        or is_dask_array(x)
-        or is_jax_array(x)
-        or is_pydata_sparse_array(x)
-        or hasattr(x, "__array_namespace__")
+        hasattr(x, '__array_namespace__')
+        or _is_array_api_cls(cast(Hashable, type(x)))
     )
+
+
+@lru_cache(100)
+def _is_array_api_cls(cls: type) -> bool:
+    return (
+        # TODO: drop support for numpy<2 which didn't have __array_namespace__
+        _issubclass_fast(cls, "numpy", "ndarray")
+        or _issubclass_fast(cls, "numpy", "generic")
+        or _issubclass_fast(cls, "cupy", "ndarray")
+        or _issubclass_fast(cls, "torch", "Tensor")
+        or _issubclass_fast(cls, "dask.array", "Array")
+        or _issubclass_fast(cls, "sparse", "SparseArray")
+        # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
+        or _issubclass_fast(cls, "jax", "Array")
+    )
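Note the fallback branch: any object whose class advertises `__array_namespace__` passes, even if it matches none of the known backends. A minimal sketch with a hypothetical class:

class MyDuckArray:
    def __array_namespace__(self, api_version=None):
        ...  # would return the array-API namespace module

assert is_array_api_obj(MyDuckArray())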


@@ -317,6 +307,7 @@ def _compat_module_name() -> str:
return __name__.removesuffix(".common._helpers")


+@lru_cache(100)
def is_numpy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
@@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}


+@lru_cache(100)
def is_cupy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
@@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}


+@lru_cache(100)
def is_torch_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
@@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
return xp.__name__ == "ndonnx"


+@lru_cache(100)
def is_dask_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a Dask namespace.
@@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
return None if math.isnan(out) else out


+@lru_cache(100)
+def _is_writeable_cls(cls: type) -> bool | None:
+    if (
+        _issubclass_fast(cls, "numpy", "generic")
+        or _issubclass_fast(cls, "jax", "Array")
+        or _issubclass_fast(cls, "sparse", "SparseArray")
+    ):
+        return False
+    if _is_array_api_cls(cls):
+        return True
+    return None


def is_writeable_array(x: object) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -949,11 +956,32 @@
As there is no standard way to check if an array is writeable without actually
writing to it, this function blindly returns True for all unknown array types.
"""
-    if is_numpy_array(x):
-        return x.flags.writeable
-    if is_jax_array(x) or is_pydata_sparse_array(x):
-        return False
-    return is_array_api_obj(x)
+    cls = cast(Hashable, type(x))
+    if _issubclass_fast(cls, "numpy", "ndarray"):
+        return cast("npt.NDArray", x).flags.writeable
+    res = _is_writeable_cls(cls)
+    if res is not None:
+        return res
+    return hasattr(x, '__array_namespace__')
+
+
+@lru_cache(100)
+def _is_lazy_cls(cls: type) -> bool | None:
+    if (
+        _issubclass_fast(cls, "numpy", "ndarray")
+        or _issubclass_fast(cls, "numpy", "generic")
+        or _issubclass_fast(cls, "cupy", "ndarray")
+        or _issubclass_fast(cls, "torch", "Tensor")
+        or _issubclass_fast(cls, "sparse", "SparseArray")
+    ):
+        return False
+    if (
+        _issubclass_fast(cls, "jax", "Array")
+        or _issubclass_fast(cls, "dask.array", "Array")
+        or _issubclass_fast(cls, "ndonnx", "Array")
+    ):
+        return True
+    return None
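A quick usage sketch of `is_writeable_array` above (assuming numpy is installed), exercising the NumPy fast path that reads `flags.writeable` directly:

import numpy as np

a = np.arange(3)
assert is_writeable_array(a)

a.flags.writeable = False
assert not is_writeable_array(a)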


def is_lazy_array(x: object) -> bool:
@@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool:
     This function errs on the side of caution for array types that may or may not be
     lazy, e.g. JAX arrays, by always returning True for them.
     """
-    if (
-        is_numpy_array(x)
-        or is_cupy_array(x)
-        or is_torch_array(x)
-        or is_pydata_sparse_array(x)
-    ):
-        return False
-
# **JAX note:** while it is possible to determine if you're inside or outside
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool:
# compatibility, is highly detrimental to performance as the whole graph will end
# up being computed multiple times.

-    if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
-        return True
+    # Note: skipping reclassification of JAX zero gradient arrays, as one will
+    # exclusively get them once they leave a jax.grad JIT context.
+    cls = cast(Hashable, type(x))
+    res = _is_lazy_cls(cls)
+    if res is not None:
+        return res

-    if not is_array_api_obj(x):
+    if not hasattr(x, "__array_namespace__"):
return False

# Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool:
"to_device",
]

_all_ignore = ["sys", "math", "inspect", "warnings"]
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']

def __dir__() -> list[str]:
return __all__
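To close, a small end-to-end sketch of the cached lazy check (assuming numpy and dask are installed, and that `is_lazy_array` is re-exported from the package's top level):

import dask.array as da
import numpy as np
from array_api_compat import is_lazy_array

assert is_lazy_array(da.zeros(3))      # Dask arrays are lazy
assert not is_lazy_array(np.zeros(3))  # NumPy arrays are eager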