From 632f08119d4c5da923d735162764e3388c058bda Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 15 Apr 2025 17:27:26 +0100 Subject: [PATCH 1/6] ENH: cache helper functions --- array_api_compat/common/_helpers.py | 120 +++++++++++++--------------- 1 file changed, 55 insertions(+), 65 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..7c46676f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -11,24 +11,42 @@ import math import inspect import warnings +from functools import cache from typing import Optional, Union, Any from ._typing import Array, Device, Namespace -def _is_jax_zero_gradient_array(x: object) -> bool: +@cache +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + +def _is_jax_zero_gradient_array(x: Array) -> bool: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + # Fast exit + try: + dtype = x.dtype + except AttributeError: + return False + if not _issubclass_fast(type(dtype), "numpy.dtypes", "VoidDType"): return False - import numpy as np - import jax + if "jax" not in sys.modules: + return False - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 def is_numpy_array(x: object) -> bool: @@ -52,15 +70,12 @@ def is_numpy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + cls = type(x) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -84,14 +99,7 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: - return False - - import cupy as cp - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + return _issubclass_fast(type(x), "cupy", "ndarray") def is_torch_array(x: object) -> bool: @@ -112,14 +120,7 @@ def is_torch_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) + return _issubclass_fast(type(x), "torch", "Tensor") def is_ndonnx_array(x: object) -> bool: @@ -141,13 +142,7 @@ def is_ndonnx_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + return _issubclass_fast(type(x), "ndonnx", "Array") def is_dask_array(x: object) -> bool: @@ -169,13 +164,7 @@ def is_dask_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: - return False - - import dask.array - - return isinstance(x, dask.array.Array) + return _issubclass_fast(type(x), "dask.array", "Array") def is_jax_array(x: object) -> bool: @@ -198,13 +187,7 @@ def is_jax_array(x: object) -> bool: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + return _issubclass_fast(type(x), "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x) -> bool: @@ -227,14 +210,8 @@ def is_pydata_sparse_array(x) -> bool: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: - return False - - import sparse - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + return _issubclass_fast(type(x), "sparse", "SparseArray") def is_array_api_obj(x: object) -> bool: @@ -252,13 +229,22 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x)) + + +@cache +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") + ) def _compat_module_name() -> str: @@ -266,6 +252,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') +@cache def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -287,6 +274,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} +@cache def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -308,6 +296,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} +@cache def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -348,6 +337,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == 'ndonnx' +@cache def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -952,4 +942,4 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings'] From fc6b56bcafa3c179d8a7b0e8dbed2156fb27afa7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 15 Apr 2025 18:18:18 +0100 Subject: [PATCH 2/6] is_lazy_array, is_writeable_array --- array_api_compat/common/_helpers.py | 57 ++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 7c46676f..282dee2e 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -839,6 +839,19 @@ def size(x: Array) -> int | None: return None if math.isnan(out) else out +@cache +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. @@ -849,11 +862,32 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): + cls = type(x) + if _issubclass_fast(cls, "numpy", "ndarray"): return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@cache +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None def is_lazy_array(x: object) -> bool: @@ -869,14 +903,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -886,10 +912,13 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + res = _is_lazy_cls(type(x)) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. Note that this test may have dire consequences From 931faae69cd67482ba8f8d5ccbd61316712184fc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 18 Apr 2025 11:08:28 +0100 Subject: [PATCH 3/6] Merge --- array_api_compat/common/_helpers.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 305939d9..ae4ba6e3 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -13,6 +13,7 @@ import sys import warnings from collections.abc import Collection +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -61,8 +62,7 @@ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) -def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: -@cache +@lru_cache(100) def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: try: mod = sys.modules[modname] @@ -72,6 +72,7 @@ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: return issubclass(cls, parent_cls) +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. @@ -276,7 +277,7 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x)) -@cache +@lru_cache(100) def _is_array_api_cls(cls: type) -> bool: return ( # TODO: drop support for numpy<2 which didn't have __array_namespace__ @@ -296,7 +297,7 @@ def _compat_module_name() -> str: return __name__.removesuffix(".common._helpers") -@cache +@lru_cache(100) def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -318,7 +319,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} -@cache +@lru_cache(100) def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -340,7 +341,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} -@cache +@lru_cache(100) def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -381,7 +382,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == "ndonnx" -@cache +@lru_cache(100) def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -922,7 +923,7 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: return None if math.isnan(out) else out -@cache +@lru_cache(100) def _is_writeable_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "generic") @@ -954,7 +955,7 @@ def is_writeable_array(x: object) -> bool: return hasattr(x, '__array_namespace__') -@cache +@lru_cache(100) def _is_lazy_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "ndarray") @@ -1054,7 +1055,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings'] +_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] def __dir__() -> list[str]: return __all__ From 6db8b11e4e9778ad3019bed4a2f6d136d60a0477 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 18 Apr 2025 11:11:43 +0100 Subject: [PATCH 4/6] wip --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index ae4ba6e3..2ce95756 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -80,7 +80,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """ # Fast exit try: - dtype = x.dtype + dtype = x.dtype # type: ignore[attr-defined] except AttributeError: return False if not _issubclass_fast(type(dtype), "numpy.dtypes", "VoidDType"): From 97f696799c61ad4294873cf8dc35d34e541508f1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 18 Apr 2025 11:23:04 +0100 Subject: [PATCH 5/6] wip --- array_api_compat/common/_helpers.py | 37 +++++++++++++++++++---------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2ce95756..f1739de3 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,7 @@ import math import sys import warnings -from collections.abc import Collection +from collections.abc import Collection, Hashable from functools import lru_cache from typing import ( TYPE_CHECKING, @@ -83,7 +83,8 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: dtype = x.dtype # type: ignore[attr-defined] except AttributeError: return False - if not _issubclass_fast(type(dtype), "numpy.dtypes", "VoidDType"): + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False if "jax" not in sys.modules: @@ -116,7 +117,7 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: is_pydata_sparse_array """ # TODO: Should we reject ndarray subclasses? - cls = type(x) + cls = cast(Hashable, type(x)) return ( _issubclass_fast(cls, "numpy", "ndarray") or _issubclass_fast(cls, "numpy", "generic") @@ -144,7 +145,8 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - return _issubclass_fast(type(x), "cupy", "ndarray") + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") def is_torch_array(x: object) -> TypeIs[torch.Tensor]: @@ -165,7 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]: is_jax_array is_pydata_sparse_array """ - return _issubclass_fast(type(x), "torch", "Tensor") + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: @@ -187,7 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: is_jax_array is_pydata_sparse_array """ - return _issubclass_fast(type(x), "ndonnx", "Array") + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") def is_dask_array(x: object) -> TypeIs[da.Array]: @@ -209,7 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]: is_jax_array is_pydata_sparse_array """ - return _issubclass_fast(type(x), "dask.array", "Array") + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") def is_jax_array(x: object) -> TypeIs[jax.Array]: @@ -232,7 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_dask_array is_pydata_sparse_array """ - return _issubclass_fast(type(x), "jax", "Array") or _is_jax_zero_gradient_array(x) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -256,7 +262,8 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: is_jax_array """ # TODO: Account for other backends. - return _issubclass_fast(type(x), "sparse", "SparseArray") + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] @@ -274,7 +281,10 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo is_dask_array is_jax_array """ - return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x)) + return ( + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) @lru_cache(100) @@ -946,9 +956,9 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - cls = type(x) + cls = cast(Hashable, type(x)) if _issubclass_fast(cls, "numpy", "ndarray"): - return x.flags.writeable + return cast(npt.NDArray, x).flags.writeable res = _is_writeable_cls(cls) if res is not None: return res @@ -998,7 +1008,8 @@ def is_lazy_array(x: object) -> bool: # Note: skipping reclassification of JAX zero gradient arrays, as one will # exclusively get them once they leave a jax.grad JIT context. - res = _is_lazy_cls(type(x)) + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) if res is not None: return res From 969ae83c18ee5696ee094bc8c29f034a72461bb1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 18 Apr 2025 11:24:07 +0100 Subject: [PATCH 6/6] fix --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index f1739de3..d50e0d83 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -958,7 +958,7 @@ def is_writeable_array(x: object) -> bool: """ cls = cast(Hashable, type(x)) if _issubclass_fast(cls, "numpy", "ndarray"): - return cast(npt.NDArray, x).flags.writeable + return cast("npt.NDArray", x).flags.writeable res = _is_writeable_cls(cls) if res is not None: return res