From 1140a1af296200bcb78e455a9e33ba2ea5bf82e9 Mon Sep 17 00:00:00 2001
From: Matt Haberland <mhaberla@calpoly.edu>
Date: Wed, 16 Apr 2025 10:49:29 -0700
Subject: [PATCH 1/6] WIP: xp_assert enhancements

---
 src/array_api_extra/_lib/_testing.py | 131 +++++++++++++++------------
 tests/test_testing.py                |  52 ++++++++++-
 2 files changed, 122 insertions(+), 61 deletions(-)

diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index 319297c8..36822892 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -9,6 +9,7 @@
 from types import ModuleType
 from typing import cast
 
+import numpy as np
 import pytest
 
 from ._utils._compat import (
@@ -16,6 +17,7 @@
     is_array_api_strict_namespace,
     is_cupy_namespace,
     is_dask_namespace,
+    is_numpy_namespace,
     is_pydata_sparse_namespace,
     is_torch_namespace,
 )
@@ -25,7 +27,11 @@
 
 
 def _check_ns_shape_dtype(
-    actual: Array, desired: Array
+    actual: Array,
+    desired: Array,
+    check_dtype: bool,
+    check_shape: bool,
+    check_scalar: bool,
 ) -> ModuleType:  # numpydoc ignore=RT03
     """
     Assert that namespace, shape and dtype of the two arrays match.
@@ -47,43 +53,64 @@ def _check_ns_shape_dtype(
     msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
     assert actual_xp == desired_xp, msg
 
-    actual_shape = actual.shape
-    desired_shape = desired.shape
-    if is_dask_namespace(desired_xp):
-        # Dask uses nan instead of None for unknown shapes
-        if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
-            actual_shape = actual.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-        if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
-            desired_shape = desired.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-
-    msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
-    assert actual_shape == desired_shape, msg
-
-    msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
-    assert actual.dtype == desired.dtype, msg
+    if check_shape:
+        actual_shape = actual.shape
+        desired_shape = desired.shape
+        if is_dask_namespace(desired_xp):
+            # Dask uses nan instead of None for unknown shapes
+            if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
+                actual_shape = actual.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+            if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
+                desired_shape = desired.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+
+        msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
+        assert actual_shape == desired_shape, msg
+
+    if check_dtype:
+        msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
+        assert actual.dtype == desired.dtype, msg
+
+    if is_numpy_namespace(actual_xp) and check_scalar:
+        # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
+        _msg = (
+            "array-ness does not match:\n Actual: "
+            f"{type(actual)}\n Desired: {type(desired)}"
+        )
+        assert (np.isscalar(actual) and np.isscalar(desired)) or (
+            not np.isscalar(actual) and not np.isscalar(desired)
+        ), _msg
 
     return desired_xp
 
 
 def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
     """
-    Ensure that the array can be compared with xp.testing or np.testing.
+    Ensure that the array can be compared with np.testing.
 
     This involves transferring it from GPU to CPU memory, densifying it, etc.
     """
     if is_torch_namespace(xp):
-        return array.cpu()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+        return np.asarray(array.cpu())  # type: ignore[attr-defined, return-value]  # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType]
     if is_pydata_sparse_namespace(xp):
         return array.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
     if is_array_api_strict_namespace(xp):
         # Note: we deliberately did not add a `.to_device` method in _typing.pyi
         # even if it is required by the standard as many backends don't support it
         return array.to_device(xp.Device("CPU_DEVICE"))  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-    # Note: nothing to do for CuPy, because it uses a bespoke test function
+    if is_cupy_namespace(xp):
+        return xp.asnumpy(array)
     return array
 
 
-def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
+def xp_assert_equal(
+    actual: Array,
+    desired: Array,
+    *,
+    err_msg: str = "",
+    check_dtype: bool = True,
+    check_shape: bool = True,
+    check_scalar: bool = False,
+) -> None:
     """
     Array-API compatible version of `np.testing.assert_array_equal`.
 
@@ -95,34 +122,21 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
         The expected array (typically hardcoded).
     err_msg : str, optional
         Error message to display on failure.
+    check_dtype, check_shape : bool, default: True
+        Whether to check agreement between actual and desired dtypes and shapes
+    check_scalar : bool, default: False
+        NumPy only: whether to check agreement between actual and desired types -
+        0d array vs scalar.
 
     See Also
     --------
     xp_assert_close : Similar function for inexact equality checks.
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
-    xp = _check_ns_shape_dtype(actual, desired)
+    xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
     actual = _prepare_for_test(actual, xp)
     desired = _prepare_for_test(desired, xp)
-
-    if is_cupy_namespace(xp):
-        xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
-    elif is_torch_namespace(xp):
-        # PyTorch recommends using `rtol=0, atol=0` like this
-        # to test for exact equality
-        xp.testing.assert_close(
-            actual,
-            desired,
-            rtol=0,
-            atol=0,
-            equal_nan=True,
-            check_dtype=False,
-            msg=err_msg or None,
-        )
-    else:
-        import numpy as np  # pylint: disable=import-outside-toplevel
-
-        np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
+    np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
 
 
 def xp_assert_close(
@@ -132,6 +146,9 @@ def xp_assert_close(
     rtol: float | None = None,
     atol: float = 0,
     err_msg: str = "",
+    check_dtype: bool = True,
+    check_shape: bool = True,
+    check_scalar: bool = False,
 ) -> None:
     """
     Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +165,11 @@ def xp_assert_close(
         Absolute tolerance. Default: 0.
     err_msg : str, optional
         Error message to display on failure.
+    check_dtype, check_shape : bool, default: True
+        Whether to check agreement between actual and desired dtypes and shapes
+    check_scalar : bool, default: False
+        NumPy only: whether to check agreement between actual and desired types -
+        0d array vs scalar.
 
     See Also
     --------
@@ -159,7 +181,7 @@ def xp_assert_close(
     -----
     The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
     """
-    xp = _check_ns_shape_dtype(actual, desired)
+    xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
 
     floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
     if rtol is None and floating:
@@ -173,26 +195,15 @@ def xp_assert_close(
     actual = _prepare_for_test(actual, xp)
     desired = _prepare_for_test(desired, xp)
 
-    if is_cupy_namespace(xp):
-        xp.testing.assert_allclose(
-            actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
-        )
-    elif is_torch_namespace(xp):
-        xp.testing.assert_close(
-            actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
-        )
-    else:
-        import numpy as np  # pylint: disable=import-outside-toplevel
-
-        # JAX/Dask arrays work directly with `np.testing`
-        assert isinstance(rtol, float)
-        np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
-            actual,  # pyright: ignore[reportArgumentType]
-            desired,  # pyright: ignore[reportArgumentType]
-            rtol=rtol,
-            atol=atol,
-            err_msg=err_msg,
-        )
+    # JAX/Dask arrays work directly with `np.testing`
+    assert isinstance(rtol, float)
+    np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
+        actual,  # pyright: ignore[reportArgumentType]
+        desired,  # pyright: ignore[reportArgumentType]
+        rtol=rtol,
+        atol=atol,
+        err_msg=err_msg,
+    )
 
 
 def xfail(request: pytest.FixtureRequest, reason: str) -> None:
diff --git a/tests/test_testing.py b/tests/test_testing.py
index ff67121b..1f31d282 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from contextlib import nullcontext
 from types import ModuleType
 from typing import cast
 
@@ -24,7 +25,9 @@
         xp_assert_equal,
         pytest.param(
             xp_assert_close,
-            marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
+            marks=pytest.mark.xfail_xp_backend(
+                Backend.SPARSE, reason="no isdtype", strict=False
+            ),
         ),
     ],
 )
@@ -60,6 +63,53 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
         func(xp.asarray([0]), [0])
 
 
+@param_assert_equal_close
+@pytest.mark.parametrize("check_shape", [False, True])
+def test_assert_close_equal_shape(  # type: ignore[explicit-any]
+    xp: ModuleType,
+    func: Callable[..., None],
+    check_shape: bool,
+):
+    context = (
+        pytest.raises(AssertionError, match="shapes do not match")
+        if check_shape
+        else nullcontext()
+    )
+    with context:
+        func(xp.asarray([0, 0]), xp.asarray(0), check_shape=check_shape)
+
+
+@param_assert_equal_close
+@pytest.mark.parametrize("check_dtype", [False, True])
+def test_assert_close_equal_dtype(  # type: ignore[explicit-any]
+    xp: ModuleType,
+    func: Callable[..., None],
+    check_dtype: bool,
+):
+    context = (
+        pytest.raises(AssertionError, match="dtypes do not match")
+        if check_dtype
+        else nullcontext()
+    )
+    with context:
+        func(xp.asarray(0.0), xp.asarray(0), check_dtype=check_dtype)
+
+
+@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
+@pytest.mark.parametrize("check_scalar", [False, True])
+def test_assert_close_equal_scalar(  # type: ignore[explicit-any]
+    func: Callable[..., None],
+    check_scalar: bool,
+):
+    context = (
+        pytest.raises(AssertionError, match="array-ness does not match")
+        if check_scalar
+        else nullcontext()
+    )
+    with context:
+        func(np.asarray(0), np.asarray(0)[()], check_scalar=check_scalar)
+
+
 @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
 def test_assert_close_tolerance(xp: ModuleType):
     xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)

From d7f754931276c50969ebbf70bf404ac96216c5d4 Mon Sep 17 00:00:00 2001
From: Matt Haberland <mhaberla@calpoly.edu>
Date: Mon, 21 Apr 2025 10:47:54 -0700
Subject: [PATCH 2/6] ENH: add xp_assert_less

---
 src/array_api_extra/_lib/_testing.py | 40 ++++++++++++++++++++++---
 tests/test_testing.py                | 44 ++++++++++++++++++++--------
 2 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index 41eeddc7..c7f700ed 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -76,9 +76,7 @@ def _check_ns_shape_dtype(
             "array-ness does not match:\n Actual: "
             f"{type(actual)}\n Desired: {type(desired)}"
         )
-        assert (np.isscalar(actual) and np.isscalar(desired)) or (
-            not np.isscalar(actual) and not np.isscalar(desired)
-        ), _msg
+        assert np.isscalar(actual) == np.isscalar(desired), _msg
 
     return desired_xp
 
@@ -139,6 +137,41 @@ def xp_assert_equal(
     np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
 
 
+def xp_assert_less(
+    x: Array,
+    y: Array,
+    *,
+    err_msg: str = "",
+    check_dtype: bool = True,
+    check_shape: bool = True,
+    check_scalar: bool = False,
+) -> None:
+    """
+    Array-API compatible version of `np.testing.assert_array_less`.
+
+    Parameters
+    ----------
+    x, y : Array
+        The arrays to compare according to ``x < y`` (elementwise).
+    err_msg : str, optional
+        Error message to display on failure.
+    check_dtype, check_shape : bool, default: True
+        Whether to check agreement between actual and desired dtypes and shapes
+    check_scalar : bool, default: False
+        NumPy only: whether to check agreement between actual and desired types -
+        0d array vs scalar.
+
+    See Also
+    --------
+    xp_assert_close : Similar function for inexact equality checks.
+    numpy.testing.assert_array_equal : Similar function for NumPy arrays.
+    """
+    xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
+    x = _prepare_for_test(x, xp)
+    y = _prepare_for_test(y, xp)
+    np.testing.assert_array_less(x, y, err_msg=err_msg)  # type: ignore[call-overload]
+
+
 def xp_assert_close(
     actual: Array,
     desired: Array,
@@ -196,7 +229,6 @@ def xp_assert_close(
     desired = _prepare_for_test(desired, xp)
 
     # JAX/Dask arrays work directly with `np.testing`
-    assert isinstance(rtol, float)
     np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
         actual,  # pyright: ignore[reportArgumentType]
         desired,  # pyright: ignore[reportArgumentType]
diff --git a/tests/test_testing.py b/tests/test_testing.py
index 1f31d282..a5dd12d1 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -7,7 +7,11 @@
 import pytest
 
 from array_api_extra._lib._backends import Backend
-from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
+from array_api_extra._lib._testing import (
+    xp_assert_close,
+    xp_assert_equal,
+    xp_assert_less,
+)
 from array_api_extra._lib._utils._compat import (
     array_namespace,
     is_dask_namespace,
@@ -23,6 +27,7 @@
     "func",
     [
         xp_assert_equal,
+        xp_assert_less,
         pytest.param(
             xp_assert_close,
             marks=pytest.mark.xfail_xp_backend(
@@ -33,7 +38,8 @@
 )
 
 
-@param_assert_equal_close
+@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
+@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
 def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]):  # type: ignore[explicit-any]
     func(xp.asarray(0), xp.asarray(0))
     func(xp.asarray([1, 2]), xp.asarray([1, 2]))
@@ -53,8 +59,8 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]):  #
 
 @pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
 @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
-@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
-def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]):  # type: ignore[explicit-any]
+@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
+def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]):  # type: ignore[explicit-any]
     with pytest.raises(AssertionError, match="namespaces do not match"):
         func(xp.asarray(0), np.asarray(0))
     with pytest.raises(TypeError, match="Unrecognized array input"):
@@ -65,7 +71,7 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
 
 @param_assert_equal_close
 @pytest.mark.parametrize("check_shape", [False, True])
-def test_assert_close_equal_shape(  # type: ignore[explicit-any]
+def test_assert_close_equal_less_shape(  # type: ignore[explicit-any]
     xp: ModuleType,
     func: Callable[..., None],
     check_shape: bool,
@@ -76,12 +82,12 @@ def test_assert_close_equal_shape(  # type: ignore[explicit-any]
         else nullcontext()
     )
     with context:
-        func(xp.asarray([0, 0]), xp.asarray(0), check_shape=check_shape)
+        func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
 
 
 @param_assert_equal_close
 @pytest.mark.parametrize("check_dtype", [False, True])
-def test_assert_close_equal_dtype(  # type: ignore[explicit-any]
+def test_assert_close_equal_less_dtype(  # type: ignore[explicit-any]
     xp: ModuleType,
     func: Callable[..., None],
     check_dtype: bool,
@@ -92,12 +98,17 @@ def test_assert_close_equal_dtype(  # type: ignore[explicit-any]
         else nullcontext()
     )
     with context:
-        func(xp.asarray(0.0), xp.asarray(0), check_dtype=check_dtype)
+        func(
+            xp.asarray(xp.nan, dtype=xp.float32),
+            xp.asarray(xp.nan, dtype=xp.float64),
+            check_dtype=check_dtype,
+        )
 
 
-@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
+@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
 @pytest.mark.parametrize("check_scalar", [False, True])
-def test_assert_close_equal_scalar(  # type: ignore[explicit-any]
+def test_assert_close_equal_less_scalar(  # type: ignore[explicit-any]
+    xp: ModuleType,
     func: Callable[..., None],
     check_scalar: bool,
 ):
@@ -107,7 +118,7 @@ def test_assert_close_equal_scalar(  # type: ignore[explicit-any]
         else nullcontext()
     )
     with context:
-        func(np.asarray(0), np.asarray(0)[()], check_scalar=check_scalar)
+        func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)
 
 
 @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@@ -121,9 +132,18 @@ def test_assert_close_tolerance(xp: ModuleType):
         xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
 
 
-@param_assert_equal_close
+def test_assert_less_basic(xp: ModuleType):
+    xp_assert_less(xp.asarray(-1), xp.asarray(0))
+    xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
+    with pytest.raises(AssertionError):
+        xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
+    with pytest.raises(AssertionError, match="hello"):
+        xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
+
+
 @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
 @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
+@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
 def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]):  # type: ignore[explicit-any]
     """On Dask and other lazy backends, test that a shape with NaN's or None's
     can be compared to a real shape.

From cb3c2d6f184127b82e98c3300dd421b875dac0af Mon Sep 17 00:00:00 2001
From: Guido Imperiale <crusaderky@gmail.com>
Date: Fri, 25 Apr 2025 17:27:30 +0100
Subject: [PATCH 3/6] Rework prepare_for_test (#2)

---
 src/array_api_extra/_lib/_testing.py | 81 +++++++++++++++-------------
 tests/test_testing.py                |  9 +++-
 2 files changed, 51 insertions(+), 39 deletions(-)

diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index c7f700ed..7b5874d6 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -7,7 +7,7 @@
 
 import math
 from types import ModuleType
-from typing import cast
+from typing import Any, cast
 
 import numpy as np
 import pytest
@@ -17,13 +17,15 @@
     is_array_api_strict_namespace,
     is_cupy_namespace,
     is_dask_namespace,
+    is_jax_namespace,
     is_numpy_namespace,
     is_pydata_sparse_namespace,
     is_torch_namespace,
+    to_device,
 )
-from ._utils._typing import Array
+from ._utils._typing import Array, Device
 
-__all__ = ["xp_assert_close", "xp_assert_equal"]
+__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]
 
 
 def _check_ns_shape_dtype(
@@ -81,23 +83,28 @@ def _check_ns_shape_dtype(
     return desired_xp
 
 
-def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
+def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:  # type: ignore[explicit-any]
     """
-    Ensure that the array can be compared with np.testing.
-
-    This involves transferring it from GPU to CPU memory, densifying it, etc.
+    Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
     """
-    if is_torch_namespace(xp):
-        return np.asarray(array.cpu())  # type: ignore[attr-defined, return-value]  # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType]
+    if is_cupy_namespace(xp):
+        return xp.asnumpy(array)
     if is_pydata_sparse_namespace(xp):
         return array.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+
+    if is_torch_namespace(xp):
+        array = to_device(array, "cpu")
     if is_array_api_strict_namespace(xp):
-        # Note: we deliberately did not add a `.to_device` method in _typing.pyi
-        # even if it is required by the standard as many backends don't support it
-        return array.to_device(xp.Device("CPU_DEVICE"))  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-    if is_cupy_namespace(xp):
-        return xp.asnumpy(array)
-    return array
+        cpu: Device = xp.Device("CPU_DEVICE")
+        array = to_device(array, cpu)
+    if is_jax_namespace(xp):
+        import jax
+
+        # Note: only needed if the transfer guard is enabled
+        cpu = cast(Device, jax.devices("cpu")[0])
+        array = to_device(array, cpu)
+
+    return np.asarray(array)
 
 
 def xp_assert_equal(
@@ -132,9 +139,9 @@ def xp_assert_equal(
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
     xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
-    actual = _prepare_for_test(actual, xp)
-    desired = _prepare_for_test(desired, xp)
-    np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
+    actual_np = as_numpy_array(actual, xp=xp)
+    desired_np = as_numpy_array(desired, xp=xp)
+    np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
 
 
 def xp_assert_less(
@@ -167,9 +174,9 @@ def xp_assert_less(
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
     xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
-    x = _prepare_for_test(x, xp)
-    y = _prepare_for_test(y, xp)
-    np.testing.assert_array_less(x, y, err_msg=err_msg)  # type: ignore[call-overload]
+    x_np = as_numpy_array(x, xp=xp)
+    y_np = as_numpy_array(y, xp=xp)
+    np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
 
 
 def xp_assert_close(
@@ -216,23 +223,21 @@ def xp_assert_close(
     """
     xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
 
-    floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
-    if rtol is None and floating:
-        # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
-        # roughly half way between sqrt(eps) and the default for
-        # `numpy.testing.assert_allclose`, 1e-7
-        rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
-    elif rtol is None:
-        rtol = 1e-7
-
-    actual = _prepare_for_test(actual, xp)
-    desired = _prepare_for_test(desired, xp)
-
-    # JAX/Dask arrays work directly with `np.testing`
-    np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
-        actual,  # pyright: ignore[reportArgumentType]
-        desired,  # pyright: ignore[reportArgumentType]
-        rtol=rtol,
+    if rtol is None:
+        if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
+            # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
+            # roughly half way between sqrt(eps) and the default for
+            # `numpy.testing.assert_allclose`, 1e-7
+            rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
+        else:
+            rtol = 1e-7
+
+    actual_np = as_numpy_array(actual, xp=xp)
+    desired_np = as_numpy_array(desired, xp=xp)
+    np.testing.assert_allclose(  # pyright: ignore[reportCallIssue]
+        actual_np,
+        desired_np,
+        rtol=rtol,  # pyright: ignore[reportArgumentType]
         atol=atol,
         err_msg=err_msg,
     )
diff --git a/tests/test_testing.py b/tests/test_testing.py
index 22291b65..97585c96 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -8,6 +8,7 @@
 
 from array_api_extra._lib._backends import Backend
 from array_api_extra._lib._testing import (
+    as_numpy_array,
     xp_assert_close,
     xp_assert_equal,
     xp_assert_less,
@@ -17,7 +18,7 @@
     is_dask_namespace,
     is_jax_namespace,
 )
-from array_api_extra._lib._utils._typing import Array
+from array_api_extra._lib._utils._typing import Array, Device
 from array_api_extra.testing import lazy_xp_function
 
 # mypy: disable-error-code=decorated-any
@@ -38,6 +39,12 @@
 )
 
 
+def test_as_numpy_array(xp: ModuleType, device: Device):
+    x = xp.asarray([1, 2, 3], device=device)
+    y = as_numpy_array(x, xp=xp)
+    assert isinstance(y, np.ndarray)
+
+
 @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
 @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
 def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]):  # type: ignore[explicit-any]

From 7347366c8dd0011014d329fdd61d9f1ededc52a3 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Fri, 25 Apr 2025 17:40:37 +0100
Subject: [PATCH 4/6] Fix failures in #267

---
 src/array_api_extra/_lib/_testing.py | 2 ++
 tests/test_funcs.py                  | 1 -
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index 7b5874d6..5645e398 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -5,6 +5,8 @@
 See also ..testing for public testing utilities.
 """
 
+from __future__ import annotations
+
 import math
 from types import ModuleType
 from typing import Any, cast
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 0cee0b4d..5f34bd60 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -196,7 +196,6 @@ def test_device(self, xp: ModuleType, device: Device):
         y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
         assert get_device(y) == device
 
-    @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
     @pytest.mark.filterwarnings("ignore::RuntimeWarning")  # overflows, etc.
     @hypothesis.settings(
         # The xp and library fixtures are not regenerated between hypothesis iterations

From 7cf08cc3ee691155e060a601d395e3496c2a1578 Mon Sep 17 00:00:00 2001
From: Matt Haberland <mhaberla@calpoly.edu>
Date: Mon, 12 May 2025 07:37:45 -0700
Subject: [PATCH 5/6] Update tests/test_testing.py

[skip ci]

Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
---
 tests/test_testing.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_testing.py b/tests/test_testing.py
index 97585c96..adb77d20 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -89,6 +89,7 @@ def test_assert_close_equal_less_shape(  # type: ignore[explicit-any]
         else nullcontext()
     )
     with context:
+        # note: NaNs are handled by all 3 checks
         func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
 
 

From 868995f42d0b345a6c2048947830afcc5b0f3484 Mon Sep 17 00:00:00 2001
From: Matt Haberland <mhaberla@calpoly.edu>
Date: Mon, 12 May 2025 07:54:39 -0700
Subject: [PATCH 6/6] Update _testing.py

---
 src/array_api_extra/_lib/_testing.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index 5645e398..c60cf466 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -46,6 +46,11 @@ def _check_ns_shape_dtype(
         The array produced by the tested function.
     desired : Array
         The expected array (typically hardcoded).
+    check_dtype, check_shape : bool, default: True
+        Whether to check agreement between actual and desired dtypes and shapes
+    check_scalar : bool, default: False
+        NumPy only: whether to check agreement between actual and desired types -
+        0d array vs scalar.
 
     Returns
     -------