Skip to content

Commit 4ea617b

Browse files
committed
fixes
1 parent 6ad0d17 commit 4ea617b

File tree

4 files changed

+12
-13
lines changed

4 files changed

+12
-13
lines changed

array_api_strict/_array_object.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections.abc import Iterator
2121
from enum import IntEnum
2222
from types import ModuleType
23-
from typing import Any, Final, Literal, SupportsIndex
23+
from typing import TYPE_CHECKING, Any, Final, Literal, SupportsIndex
2424

2525
import numpy as np
2626
import numpy.typing as npt
@@ -213,7 +213,7 @@ def _check_allowed_dtypes(
213213

214214
if self.dtype not in _dtype_categories[dtype_category]:
215215
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
216-
if isinstance(other, (int, complex, float, bool)):
216+
if isinstance(other, (bool, int, float, complex)):
217217
other = self._promote_scalar(other)
218218
elif isinstance(other, Array):
219219
if other.dtype not in _dtype_categories[dtype_category]:
@@ -243,7 +243,7 @@ def _check_allowed_dtypes(
243243

244244
def _check_device(self, other: Array | complex) -> None:
245245
"""Check that other is on a device compatible with the current array"""
246-
if isinstance(other, (int, complex, float, bool)):
246+
if isinstance(other, (bool, int, float, complex)):
247247
return
248248
elif isinstance(other, Array):
249249
if self.device != other.device:
@@ -1098,7 +1098,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
10981098
res = self._array.__rmatmul__(other._array)
10991099
return self.__class__._new(res, device=self.device)
11001100

1101-
def __imod__(self, other: Array | complex, /) -> Array:
1101+
def __imod__(self, other: Array | float, /) -> Array:
11021102
"""
11031103
Performs the operation __imod__.
11041104
"""

array_api_strict/_creation_functions.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Generator
44
from contextlib import contextmanager
55
from enum import Enum
6-
from typing import TYPE_CHECKING, Literal, cast
6+
from typing import TYPE_CHECKING, Literal
77

88
import numpy as np
99

@@ -18,6 +18,7 @@
1818
# Circular import
1919
from ._array_object import Array, Device
2020

21+
2122
class Undef(Enum):
2223
UNDEF = 0
2324

@@ -316,8 +317,7 @@ def linspace(
316317
)
317318

318319

319-
# Note: indexing was 'str' in <=2024.12
320-
def meshgrid(*arrays: Array, indexing: str = "xy") -> list[Array]:
320+
def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]:
321321
"""
322322
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
323323
@@ -340,11 +340,9 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> list[Array]:
340340
else:
341341
device = None
342342

343-
np_indexing = cast(Literal["xy", "ij"], indexing)
344-
345343
return [
346344
Array._new(array, device=device)
347-
for array in np.meshgrid(*[a._array for a in arrays], indexing=np_indexing)
345+
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
348346
]
349347

350348

array_api_strict/_elementwise_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def ceil(x: Array, /) -> Array:
270270
def clip(
271271
x: Array,
272272
/,
273-
min: Array | complex | None = None,
274-
max: Array | complex | None = None,
273+
min: Array | float | None = None,
274+
max: Array | float | None = None,
275275
) -> Array:
276276
"""
277277
Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`.

array_api_strict/_flags.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"2021.12",
3838
"2022.12",
3939
"2023.12",
40-
"2024.12"
40+
"2024.12",
4141
)
4242

4343
draft_version = "2025.12"
@@ -390,6 +390,7 @@ def set_flags_from_environment() -> None:
390390

391391
# Decorators
392392

393+
393394
def requires_api_version(version: str) -> Callable[[_CallableT], _CallableT]:
394395
def decorator(func: Callable[P, T]) -> Callable[P, T]:
395396
@functools.wraps(func)

0 commit comments

Comments
 (0)