Skip to content

REL: prepare v0.8.0 release #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,958 changes: 977 additions & 981 deletions pixi.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from ._lib._lazy import lazy_apply

__version__ = "0.8.0.dev0"
__version__ = "0.8.0"

# pylint: disable=duplicate-code
__all__ = [
Expand Down
16 changes: 12 additions & 4 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ def meta_namespace(
return array_namespace(*metas)


def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
def capabilities(
xp: ModuleType, *, device: Device | None = None
) -> dict[str, int | None]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.

Expand All @@ -322,7 +324,11 @@ def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, i
"""
if is_pydata_sparse_namespace(xp):
# No __array_namespace_info__(); no indexing by sparse arrays
return {"boolean indexing": False, "data-dependent shapes": True}
return {
"boolean indexing": False,
"data-dependent shapes": True,
"max dimensions": None,
}
out = xp.__array_namespace_info__().capabilities()
if is_jax_namespace(xp) and out["boolean indexing"]:
# FIXME https://github.com/jax-ml/jax/issues/27418
Expand Down Expand Up @@ -418,7 +424,9 @@ class Pickler(pickle.Pickler): # numpydoc ignore=GL08
"""

@override
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
def persistent_id(
self, obj: object
) -> Literal[0, 1, None]: # numpydoc ignore=GL08
if isinstance(obj, cls):
instances.append(obj) # type: ignore[arg-type]
return 0
Expand Down Expand Up @@ -483,7 +491,7 @@ class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08
"""Mirror of the overridden Pickler in pickle_flatten."""

@override
def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
def persistent_load(self, pid: Literal[0, 1]) -> object: # numpydoc ignore=GL08
try:
return next(iters[pid])
except StopIteration as e:
Expand Down
13 changes: 8 additions & 5 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ def test_none(self, args: tuple[tuple[float | None, ...], ...]):
assert actual == expect


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestCov:
def test_basic(self, xp: ModuleType):
xp_assert_close(
Expand All @@ -417,6 +416,7 @@ def test_complex(self, xp: ModuleType):
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
xp_assert_close(actual, expect)

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue")
def test_empty(self, xp: ModuleType):
with warnings.catch_warnings(record=True):
warnings.simplefilter("always", RuntimeWarning)
Expand Down Expand Up @@ -612,7 +612,6 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
class TestDefaultDType:
def test_basic(self, xp: ModuleType):
assert default_dtype(xp) == xp.empty(0).dtype
Expand Down Expand Up @@ -697,7 +696,9 @@ def test_xp(self, xp: ModuleType):
@pytest.mark.filterwarnings( # array_api_strictest
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
)
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.filterwarnings( # sparse
"ignore:invalid value encountered:RuntimeWarning:sparse"
)
class TestIsClose:
@pytest.mark.parametrize("swap", [False, True])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -815,6 +816,7 @@ def test_bool_dtype(self, xp: ModuleType):
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape(self, xp: ModuleType):
a = xp.asarray([1, 5, 0])
Expand All @@ -823,6 +825,7 @@ def test_none_shape(self, xp: ModuleType):
a = a[a < 5]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape_bool(self, xp: ModuleType):
a = xp.asarray([True, True, False])
Expand Down Expand Up @@ -919,7 +922,6 @@ def test_kron_shape(
k = kron(a, b)
assert k.shape == expected_shape

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
def test_python_scalar(self, xp: ModuleType):
a = 1
# Test no dtype promotion to xp.asarray(a); use b.dtype
Expand Down Expand Up @@ -1138,8 +1140,8 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(actual, expected)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestSinc:
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace")
def test_simple(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))
w = sinc(xp.linspace(-1, 1, 100))
Expand All @@ -1151,6 +1153,7 @@ def test_dtype(self, xp: ModuleType, x: int | complex):
with pytest.raises(ValueError, match="real floating data type"):
_ = sinc(xp.asarray(x))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange")
def test_3d(self, xp: ModuleType):
x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2))
expected = xp.zeros((3, 3, 2), dtype=xp.float64)
Expand Down
1 change: 0 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_xp(self, xp: ModuleType):


class TestAsArrays:
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize(
("dtype", "b", "defined"),
[
Expand Down
18 changes: 6 additions & 12 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,7 @@ def test_device(self, xp: ModuleType, device: Device):


class TestAssertEqualCloseLess:
pr_assert_close = pytest.param( # pyright: ignore[reportUnannotatedClassAttribute]
xp_assert_close,
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
def test_assert_equal_close_basic(self, xp: ModuleType, func: Callable[..., None]):
func(xp.asarray(0), xp.asarray(0))
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
Expand Down Expand Up @@ -75,7 +70,7 @@ def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(TypeError, match="list is not a supported array type"):
func(xp.asarray([0]), [0])

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
a = xp.asarray([1] if func is xp_assert_less else [2])
b = xp.asarray(2)
Expand All @@ -90,7 +85,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="sizes do not match"):
func(a, d, check_shape=False)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
a = xp.asarray(1 if func is xp_assert_less else 2)
b = xp.asarray(2, dtype=xp.int16)
Expand All @@ -102,7 +97,7 @@ def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_dtype=False)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
@pytest.mark.xfail_xp_backend(
Backend.SPARSE, reason="sparse [()] returns np.generic"
)
Expand All @@ -122,7 +117,6 @@ def test_check_scalar(
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_scalar=True)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize("dtype", ["int64", "float64"])
def test_assert_close_tolerance(self, dtype: str, xp: ModuleType):
a = xp.asarray([100], dtype=getattr(xp, dtype))
Expand All @@ -145,7 +139,7 @@ def test_assert_less(self, xp: ModuleType):
with pytest.raises(AssertionError, match="Mismatched elements"):
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
Expand Down Expand Up @@ -176,7 +170,7 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="Mismatched elements"):
func(xp.asarray([4]), a)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]):
a = xp.asarray([1] if func is xp_assert_less else [2], device=device)
b = xp.asarray([2], device=device)
Expand Down