diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index c7f700ed..7b5874d6 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -7,7 +7,7 @@ import math from types import ModuleType -from typing import cast +from typing import Any, cast import numpy as np import pytest @@ -17,13 +17,15 @@ is_array_api_strict_namespace, is_cupy_namespace, is_dask_namespace, + is_jax_namespace, is_numpy_namespace, is_pydata_sparse_namespace, is_torch_namespace, + to_device, ) -from ._utils._typing import Array +from ._utils._typing import Array, Device -__all__ = ["xp_assert_close", "xp_assert_equal"] +__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"] def _check_ns_shape_dtype( @@ -81,23 +83,28 @@ def _check_ns_shape_dtype( return desired_xp -def _prepare_for_test(array: Array, xp: ModuleType) -> Array: +def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any] """ - Ensure that the array can be compared with np.testing. - - This involves transferring it from GPU to CPU memory, densifying it, etc. + Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards. """ - if is_torch_namespace(xp): - return np.asarray(array.cpu()) # type: ignore[attr-defined, return-value] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType] + if is_cupy_namespace(xp): + return xp.asnumpy(array) if is_pydata_sparse_namespace(xp): return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + + if is_torch_namespace(xp): + array = to_device(array, "cpu") if is_array_api_strict_namespace(xp): - # Note: we deliberately did not add a `.to_device` method in _typing.pyi - # even if it is required by the standard as many backends don't support it - return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - if is_cupy_namespace(xp): - return xp.asnumpy(array) - return array + cpu: Device = xp.Device("CPU_DEVICE") + array = to_device(array, cpu) + if is_jax_namespace(xp): + import jax + + # Note: only needed if the transfer guard is enabled + cpu = cast(Device, jax.devices("cpu")[0]) + array = to_device(array, cpu) + + return np.asarray(array) def xp_assert_equal( @@ -132,9 +139,9 @@ def xp_assert_equal( numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) - actual = _prepare_for_test(actual, xp) - desired = _prepare_for_test(desired, xp) - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) def xp_assert_less( @@ -167,9 +174,9 @@ def xp_assert_less( numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) - x = _prepare_for_test(x, xp) - y = _prepare_for_test(y, xp) - np.testing.assert_array_less(x, y, err_msg=err_msg) # type: ignore[call-overload] + x_np = as_numpy_array(x, xp=xp) + y_np = as_numpy_array(y, xp=xp) + np.testing.assert_array_less(x_np, y_np, err_msg=err_msg) def xp_assert_close( @@ -216,23 +223,21 @@ def xp_assert_close( """ xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) - floating = xp.isdtype(actual.dtype, ("real floating", "complex floating")) - if rtol is None and floating: - # multiplier of 4 is used as for `np.float64` this puts the default `rtol` - # roughly half way between sqrt(eps) and the default for - # `numpy.testing.assert_allclose`, 1e-7 - rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 - elif rtol is None: - rtol = 1e-7 - - actual = _prepare_for_test(actual, xp) - desired = _prepare_for_test(desired, xp) - - # JAX/Dask arrays work directly with `np.testing` - np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue] - actual, # pyright: ignore[reportArgumentType] - desired, # pyright: ignore[reportArgumentType] - rtol=rtol, + if rtol is None: + if xp.isdtype(actual.dtype, ("real floating", "complex floating")): + # multiplier of 4 is used as for `np.float64` this puts the default `rtol` + # roughly half way between sqrt(eps) and the default for + # `numpy.testing.assert_allclose`, 1e-7 + rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 + else: + rtol = 1e-7 + + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] + actual_np, + desired_np, + rtol=rtol, # pyright: ignore[reportArgumentType] atol=atol, err_msg=err_msg, ) diff --git a/tests/test_testing.py b/tests/test_testing.py index 22291b65..97585c96 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -8,6 +8,7 @@ from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import ( + as_numpy_array, xp_assert_close, xp_assert_equal, xp_assert_less, @@ -17,7 +18,7 @@ is_dask_namespace, is_jax_namespace, ) -from array_api_extra._lib._utils._typing import Array +from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function # mypy: disable-error-code=decorated-any @@ -38,6 +39,12 @@ ) +def test_as_numpy_array(xp: ModuleType, device: Device): + x = xp.asarray([1, 2, 3], device=device) + y = as_numpy_array(x, xp=xp) + assert isinstance(y, np.ndarray) + + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False) @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close]) def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]