Skip to content

Commit 75b32ce

Browse files
🏷️ ufunc annotations for copysign, heaviside, logaddexp[2], and nextafter (#376)
* 🏷️ stub `copysign`, `heaviside`, `logaddexp`, `logaddexp2`, `nextafter` * 🏷️ stub `copysign`, `heaviside`, `logaddexp`, `logaddexp2`, `nextafter` * 🧪 update reject test * 🐛 update src/numpy-stubs/_core/umath.pyi Co-Authored-By: Joren Hammudoglu <[email protected]> * 🧪 update test Co-authored-by: Joren Hammudoglu <[email protected]> * 🐛 update src/numpy-stubs/_core/umath.pyi Co-authored-by: Joren Hammudoglu <[email protected]> --------- Co-authored-by: Joren Hammudoglu <[email protected]>
1 parent 54b5de7 commit 75b32ce

File tree

3 files changed

+139
-5
lines changed

3 files changed

+139
-5
lines changed

src/numpy-stubs/_core/umath.pyi

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from typing import (
1414
from typing_extensions import Never, TypeAliasType, TypeVar, Unpack
1515

1616
import numpy as np
17+
from _numtype import CoFloat64_nd, CoFloating_nd, ToFloat64_1nd
1718
from numpy import _CastingKind, _OrderKACF # noqa: ICN003
1819
from numpy._typing import (
1920
ArrayLike,
@@ -746,6 +747,97 @@ class _Call21Bool(Protocol):
746747
**kwds: Unpack[_Kwargs3],
747748
) -> Any: ...
748749

750+
@type_check_only
751+
class _Call21Float(Protocol):
752+
@overload # (float, float) -> float64
753+
def __call__(
754+
self,
755+
x1: float,
756+
x2: float,
757+
/,
758+
out: None = None,
759+
*,
760+
dtype: None = None,
761+
**kwds: Unpack[_Kwargs3],
762+
) -> np.float64: ...
763+
@overload # (scalar, scalar) -> float
764+
def __call__(
765+
self,
766+
x1: _FloatLike_co,
767+
x2: _FloatLike_co,
768+
/,
769+
out: None = None,
770+
*,
771+
dtype: _DTypeLikeFloat | None = None,
772+
**kwds: Unpack[_Kwargs3],
773+
) -> np.floating: ...
774+
@overload # (array-like, array-like, out: T) -> T
775+
def __call__(
776+
self,
777+
x1: _ArrayLikeFloat_co,
778+
x2: _ArrayLikeFloat_co,
779+
/,
780+
out: _Out1[_ArrayT],
781+
*,
782+
dtype: None = None,
783+
**kwds: Unpack[_Kwargs3],
784+
) -> _ArrayT: ...
785+
@overload # (array-like, array) -> Array[float64]
786+
def __call__(
787+
self,
788+
x1: CoFloat64_nd,
789+
x2: ToFloat64_1nd,
790+
/,
791+
out: None = None,
792+
*,
793+
dtype: None = None,
794+
**kwds: Unpack[_Kwargs3],
795+
) -> NDArray[np.float64]: ...
796+
@overload # (array, array-like) -> Array[float64]
797+
def __call__(
798+
self,
799+
x1: ToFloat64_1nd,
800+
x2: CoFloat64_nd,
801+
/,
802+
out: None = None,
803+
*,
804+
dtype: None = None,
805+
**kwds: Unpack[_Kwargs3],
806+
) -> NDArray[np.float64]: ...
807+
@overload # (array-like, array) -> Array[float]
808+
def __call__(
809+
self,
810+
x1: _ArrayLikeFloat_co,
811+
x2: NDArray[np.floating] | _NestedSequence[float],
812+
/,
813+
out: None = None,
814+
*,
815+
dtype: _DTypeLikeFloat | None = None,
816+
**kwds: Unpack[_Kwargs3],
817+
) -> NDArray[np.floating]: ...
818+
@overload # (array, array-like) -> Array[float]
819+
def __call__(
820+
self,
821+
x1: NDArray[np.floating] | _NestedSequence[float],
822+
x2: _ArrayLikeFloat_co,
823+
/,
824+
out: None = None,
825+
*,
826+
dtype: _DTypeLikeFloat | None = None,
827+
**kwds: Unpack[_Kwargs3],
828+
) -> NDArray[np.floating]: ...
829+
@overload # (array-like, array-like) -> Array[float] | float
830+
def __call__(
831+
self,
832+
x1: CoFloating_nd,
833+
x2: CoFloating_nd,
834+
/,
835+
out: _Out1[_AnyArray] | None = None,
836+
*,
837+
dtype: _DTypeLikeFloat | None = None,
838+
**kwds: Unpack[_Kwargs3],
839+
) -> Any: ...
840+
749841
@type_check_only
750842
class _Call21Logical(Protocol):
751843
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
@@ -1550,11 +1642,11 @@ ldexp: Final[_ufunc_2_1] = ...
15501642
float_power: Final[_ufunc_2_1] = ...
15511643

15521644
# {[f]}, $1 -> $1
1553-
copysign: Final[_ufunc_2_1] = ...
1554-
heaviside: Final[_ufunc_2_1] = ...
1555-
logaddexp: Final[_ufunc_2_1] = ...
1556-
logaddexp2: Final[_ufunc_2_1] = ...
1557-
nextafter: Final[_ufunc_2_1] = ...
1645+
copysign: Final[_ufunc_2_1[_Call21Float]] = ...
1646+
heaviside: Final[_ufunc_2_1[_Call21Float]] = ...
1647+
logaddexp: Final[_ufunc_2_1[_Call21Float]] = ...
1648+
logaddexp2: Final[_ufunc_2_1[_Call21Float]] = ...
1649+
nextafter: Final[_ufunc_2_1[_Call21Float]] = ...
15581650

15591651
# {[f]O}, $1 -> $1
15601652
arctan2: Final[_ufunc_2_1] = ...

test/static/accept/ufuncs.pyi

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,38 @@ assert_type(np.radians(f8), np.float64)
130130
assert_type(np.radians(f8, dtype=np.float64), np.float64)
131131
assert_type(np.radians(AR_f8), npt.NDArray[np.float64])
132132
assert_type(np.radians(AR_f8, out=AR_f8), npt.NDArray[np.float64])
133+
134+
assert_type(np.copysign(f8, f8), np.float64)
135+
assert_type(np.copysign(AR_f4, f8), npt.NDArray[np.floating])
136+
assert_type(np.copysign(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
137+
assert_type(np.copysign(f8, AR_f8), npt.NDArray[np.float64])
138+
assert_type(np.copysign(AR_f8, AR_f8), npt.NDArray[np.float64])
139+
assert_type(np.copysign(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
140+
141+
assert_type(np.heaviside(f8, f8), np.float64)
142+
assert_type(np.heaviside(AR_f8, f8), npt.NDArray[np.float64])
143+
assert_type(np.heaviside(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
144+
assert_type(np.heaviside(f8, AR_f8), npt.NDArray[np.float64])
145+
assert_type(np.heaviside(AR_f8, AR_f8), npt.NDArray[np.float64])
146+
assert_type(np.heaviside(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
147+
148+
assert_type(np.logaddexp(f8, f8), np.float64)
149+
assert_type(np.logaddexp(AR_f8, f8), npt.NDArray[np.float64])
150+
assert_type(np.logaddexp(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
151+
assert_type(np.logaddexp(f8, AR_f8), npt.NDArray[np.float64])
152+
assert_type(np.logaddexp(AR_f8, AR_f8), npt.NDArray[np.float64])
153+
assert_type(np.logaddexp(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
154+
155+
assert_type(np.logaddexp2(f8, f8), np.float64)
156+
assert_type(np.logaddexp2(AR_f8, f8), npt.NDArray[np.float64])
157+
assert_type(np.logaddexp2(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
158+
assert_type(np.logaddexp2(f8, AR_f8), npt.NDArray[np.float64])
159+
assert_type(np.logaddexp2(AR_f8, AR_f8), npt.NDArray[np.float64])
160+
assert_type(np.logaddexp2(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
161+
162+
assert_type(np.nextafter(f8, f8), np.float64)
163+
assert_type(np.nextafter(AR_f8, f8), npt.NDArray[np.float64])
164+
assert_type(np.nextafter(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
165+
assert_type(np.nextafter(f8, AR_f8), npt.NDArray[np.float64])
166+
assert_type(np.nextafter(AR_f8, AR_f8), npt.NDArray[np.float64])
167+
assert_type(np.nextafter(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])

test/static/reject/ufuncs.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import numpy.typing as npt
44
i8: np.int64
55
AR_f8: npt.NDArray[np.float64]
66
dt64: np.datetime64
7+
AR_dt64: npt.NDArray[np.datetime64]
78

89
np.sin.nin + "foo" # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
910

@@ -84,3 +85,9 @@ np.rad2deg(AR_f8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ign
8485

8586
np.radians(dt64) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
8687
np.radians(AR_f8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
88+
89+
np.copysign(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
90+
np.heaviside(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
91+
np.logaddexp(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
92+
np.logaddexp2(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
93+
np.nextafter(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)