Skip to content

Commit 4434999

Browse files
authored
🏷️ ufunc annotations for spacing (#363)
* 🏷️ stub `umath.spacing` * ✅ add test for `umath.spacing` * 🏷️ enhance annatation with float64
1 parent 7664b03 commit 4434999

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

src/numpy-stubs/_core/umath.pyi

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ from numpy._typing import (
2121
NDArray,
2222
_ArrayLike,
2323
_ArrayLikeBool_co,
24+
_ArrayLikeFloat_co,
2425
_ArrayLikeInt_co,
2526
_ArrayLikeNumber_co,
2627
_ArrayLikeObject_co,
2728
_DTypeLike,
2829
_DTypeLikeBool,
30+
_DTypeLikeFloat,
31+
_FloatLike_co,
2932
_NestedSequence,
3033
_NumberLike_co,
3134
_ScalarLike_co,
@@ -357,6 +360,59 @@ class _Call11Bool(Protocol):
357360
**kwds: Unpack[_Kwargs2],
358361
) -> NDArray[np.bool] | np.bool: ...
359362

363+
@type_check_only
364+
class _Call11Float(Protocol):
365+
@overload # (float) -> float64
366+
def __call__(
367+
self,
368+
x: float,
369+
/,
370+
out: None = None,
371+
*,
372+
dtype: _DTypeLikeFloat | None = None,
373+
**kwds: Unpack[_Kwargs2],
374+
) -> np.float64: ...
375+
@overload # (scalar) -> float
376+
def __call__(
377+
self,
378+
x: _FloatLike_co,
379+
/,
380+
out: None = None,
381+
*,
382+
dtype: _DTypeLikeFloat | None = None,
383+
**kwds: Unpack[_Kwargs2],
384+
) -> np.floating: ...
385+
@overload # (array-like, out: T) -> T
386+
def __call__(
387+
self,
388+
x: _ArrayLikeFloat_co,
389+
/,
390+
out: _Out1[_ArrayT],
391+
*,
392+
dtype: _DTypeLikeFloat | None = None,
393+
**kwds: Unpack[_Kwargs2],
394+
) -> _ArrayT: ...
395+
@overload # (NDArray[float64] | _NestedSequence[float]) -> NDArray[float64]
396+
def __call__(
397+
self,
398+
x: NDArray[np.float64] | _NestedSequence[float],
399+
/,
400+
out: _Out1[NDArray[np.float64]] | None = None,
401+
*,
402+
dtype: _DTypeLikeFloat | None = None,
403+
**kwds: Unpack[_Kwargs2],
404+
) -> NDArray[np.float64]: ...
405+
@overload # (array) -> Array[float]
406+
def __call__(
407+
self,
408+
x: NDArray[np.floating] | _NestedSequence[np.floating],
409+
/,
410+
out: _Out1[NDArray[np.floating]] | None = None,
411+
*,
412+
dtype: _DTypeLikeFloat | None = None,
413+
**kwds: Unpack[_Kwargs2],
414+
) -> NDArray[np.floating]: ...
415+
360416
@type_check_only
361417
class _Call11Isnat(Protocol):
362418
@overload # (scalar) -> bool
@@ -1151,8 +1207,8 @@ class _ReduceAt2(Protocol):
11511207
def __call__(
11521208
self,
11531209
array: ArrayLike,
1154-
/,
11551210
indices: _ArrayLikeInt_co,
1211+
/,
11561212
axis: SupportsIndex,
11571213
dtype: _DTypeLike[_ScalarT],
11581214
out: NDArray[_ScalarT] | None = None,
@@ -1360,7 +1416,7 @@ str_len: _ufunc_1_1[_UFunc11String[np.intp]]
13601416
bitwise_count: Final[_ufunc_1_1] = ...
13611417

13621418
# {[f]} -> $1
1363-
spacing: Final[_ufunc_1_1] = ...
1419+
spacing: Final[_ufunc_1_1[_Call11Float]] = ...
13641420

13651421
# {[f]O} -> $1
13661422
cbrt: Final[_ufunc_1_1] = ...

test/static/accept/ufuncs.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ import numpy as np
55
import numpy.typing as npt
66

77
i8: np.int64
8+
f4: np.float32
89
f8: np.float64
910
dt64: np.datetime64
1011
td64: np.timedelta64
12+
AR_f4: npt.NDArray[np.float32]
1113
AR_f8: npt.NDArray[np.float64]
1214
AR_i8: npt.NDArray[np.int64]
1315
AR_bool: npt.NDArray[np.bool_]
@@ -88,3 +90,9 @@ assert_type(np.logical_xor(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_]
8890
assert_type(np.logical_xor(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8991
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_])
9092
assert_type(np.logical_xor(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
93+
94+
assert_type(np.spacing(f4), np.floating)
95+
assert_type(np.spacing(f8), np.float64)
96+
assert_type(np.spacing(AR_f8), npt.NDArray[np.float64])
97+
assert_type(np.spacing(AR_f4), npt.NDArray[np.floating])
98+
assert_type(np.spacing(AR_f8, out=AR_f8), npt.NDArray[np.float64])

test/static/reject/ufuncs.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,6 @@ np.logical_or(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyri
6060

6161
np.logical_xor(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
6262
np.logical_xor(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
63+
64+
np.spacing(dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
65+
np.spacing(AR_f8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)