Skip to content

Commit 951bbb0

Browse files
committed
🏷️ stub copysign, heaviside, logaddexp, logaddexp2, nextafter
1 parent 3e87a93 commit 951bbb0

File tree

3 files changed

+138
-5
lines changed

3 files changed

+138
-5
lines changed

src/numpy-stubs/_core/umath.pyi

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,78 @@ class _Call21Bool(Protocol):
746746
**kwds: Unpack[_Kwargs3],
747747
) -> np.bool | NDArray[np.bool]: ...
748748

749+
@type_check_only
750+
class _Call21Float(Protocol):
751+
@overload # (float, float) -> float64
752+
def __call__(
753+
self,
754+
x1: float,
755+
x2: float,
756+
/,
757+
out: None = None,
758+
dtype: _DTypeLikeFloat | None = None,
759+
**kwds: Unpack[_Kwargs3],
760+
) -> np.float64: ...
761+
@overload # (scalar, scalar) -> float
762+
def __call__(
763+
self,
764+
x1: _FloatLike_co,
765+
x2: _FloatLike_co,
766+
/,
767+
out: None = None,
768+
dtype: _DTypeLikeFloat | None = None,
769+
**kwds: Unpack[_Kwargs3],
770+
) -> np.floating: ...
771+
@overload # (array-like, array-like, out: T) -> T
772+
def __call__(
773+
self,
774+
x1: _ArrayLikeFloat_co,
775+
x2: _ArrayLikeFloat_co,
776+
/,
777+
out: _ArrayT | tuple[_ArrayT],
778+
**kwds: Unpack[_Kwargs3],
779+
) -> _ArrayT: ...
780+
@overload # (array-like, array) -> Array[float64]
781+
def __call__(
782+
self,
783+
x1: _ArrayLikeFloat_co,
784+
x2: NDArray[np.float64] | _NestedSequence[float],
785+
/,
786+
out: _Out1[_AnyArray] | None = None,
787+
dtype: _DTypeLikeFloat | None = None,
788+
**kwds: Unpack[_Kwargs3],
789+
) -> NDArray[np.float64]: ...
790+
@overload # (array, array-like) -> Array[float64]
791+
def __call__(
792+
self,
793+
x1: NDArray[np.float64] | _NestedSequence[float],
794+
x2: _ArrayLikeFloat_co,
795+
/,
796+
out: _Out1[_AnyArray] | None = None,
797+
dtype: _DTypeLikeFloat | None = None,
798+
**kwds: Unpack[_Kwargs3],
799+
) -> NDArray[np.float64]: ...
800+
@overload # (array-like, array) -> Array[float]
801+
def __call__(
802+
self,
803+
x1: _ArrayLikeFloat_co,
804+
x2: NDArray[np.floating] | _NestedSequence[float],
805+
/,
806+
out: _Out1[_AnyArray] | None = None,
807+
dtype: _DTypeLikeFloat | None = None,
808+
**kwds: Unpack[_Kwargs3],
809+
) -> NDArray[np.floating]: ...
810+
@overload # (array, array-like) -> Array[float]
811+
def __call__(
812+
self,
813+
x1: NDArray[np.floating] | _NestedSequence[float],
814+
x2: _ArrayLikeFloat_co,
815+
/,
816+
out: _Out1[_AnyArray] | None = None,
817+
dtype: _DTypeLikeFloat | None = None,
818+
**kwds: Unpack[_Kwargs3],
819+
) -> NDArray[np.floating]: ...
820+
749821
@type_check_only
750822
class _Call21Logical(Protocol):
751823
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
@@ -1550,11 +1622,11 @@ ldexp: Final[_ufunc_2_1] = ...
15501622
float_power: Final[_ufunc_2_1] = ...
15511623

15521624
# {[f]}, $1 -> $1
1553-
copysign: Final[_ufunc_2_1] = ...
1554-
heaviside: Final[_ufunc_2_1] = ...
1555-
logaddexp: Final[_ufunc_2_1] = ...
1556-
logaddexp2: Final[_ufunc_2_1] = ...
1557-
nextafter: Final[_ufunc_2_1] = ...
1625+
copysign: Final[_ufunc_2_1[_Call21Float]] = ...
1626+
heaviside: Final[_ufunc_2_1[_Call21Float]] = ...
1627+
logaddexp: Final[_ufunc_2_1[_Call21Float]] = ...
1628+
logaddexp2: Final[_ufunc_2_1[_Call21Float]] = ...
1629+
nextafter: Final[_ufunc_2_1[_Call21Float]] = ...
15581630

15591631
# {[f]O}, $1 -> $1
15601632
arctan2: Final[_ufunc_2_1] = ...

test/static/accept/ufuncs.pyi

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,38 @@ assert_type(np.radians(f8), np.float64)
130130
assert_type(np.radians(f8, dtype=np.float64), np.float64)
131131
assert_type(np.radians(AR_f8), npt.NDArray[np.float64])
132132
assert_type(np.radians(AR_f8, out=AR_f8), npt.NDArray[np.float64])
133+
134+
assert_type(np.copysign(f8, f8), np.float64)
135+
assert_type(np.copysign(AR_f4, f8), npt.NDArray[np.floating])
136+
assert_type(np.copysign(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
137+
assert_type(np.copysign(f8, AR_f8), npt.NDArray[np.float64])
138+
assert_type(np.copysign(AR_f8, AR_f8), npt.NDArray[np.float64])
139+
assert_type(np.copysign(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
140+
141+
assert_type(np.heaviside(f8, f8), np.float64)
142+
assert_type(np.heaviside(AR_f8, f8), npt.NDArray[np.float64])
143+
assert_type(np.heaviside(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
144+
assert_type(np.heaviside(f8, AR_f8), npt.NDArray[np.float64])
145+
assert_type(np.heaviside(AR_f8, AR_f8), npt.NDArray[np.float64])
146+
assert_type(np.heaviside(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
147+
148+
assert_type(np.logaddexp(f8, f8), np.float64)
149+
assert_type(np.logaddexp(AR_f8, f8), npt.NDArray[np.float64])
150+
assert_type(np.logaddexp(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
151+
assert_type(np.logaddexp(f8, AR_f8), npt.NDArray[np.float64])
152+
assert_type(np.logaddexp(AR_f8, AR_f8), npt.NDArray[np.float64])
153+
assert_type(np.logaddexp(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
154+
155+
assert_type(np.logaddexp2(f8, f8), np.float64)
156+
assert_type(np.logaddexp2(AR_f8, f8), npt.NDArray[np.float64])
157+
assert_type(np.logaddexp2(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
158+
assert_type(np.logaddexp2(f8, AR_f8), npt.NDArray[np.float64])
159+
assert_type(np.logaddexp2(AR_f8, AR_f8), npt.NDArray[np.float64])
160+
assert_type(np.logaddexp2(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])
161+
162+
assert_type(np.nextafter(f8, f8), np.float64)
163+
assert_type(np.nextafter(AR_f8, f8), npt.NDArray[np.float64])
164+
assert_type(np.nextafter(AR_f8, f8, out=AR_f8), npt.NDArray[np.float64])
165+
assert_type(np.nextafter(f8, AR_f8), npt.NDArray[np.float64])
166+
assert_type(np.nextafter(AR_f8, AR_f8), npt.NDArray[np.float64])
167+
assert_type(np.nextafter(AR_f8, AR_f8, out=AR_f8), npt.NDArray[np.float64])

test/static/reject/ufuncs.pyi

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import numpy.typing as npt
44
i8: np.int64
55
AR_f8: npt.NDArray[np.float64]
66
dt64: np.datetime64
7+
AR_dt64: npt.NDArray[np.datetime64]
78

89
np.sin.nin + "foo" # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
910

@@ -84,3 +85,28 @@ np.rad2deg(AR_f8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ign
8485

8586
np.radians(dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
8687
np.radians(AR_f8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
88+
89+
np.copysign(dt64, dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
90+
np.copysign(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
91+
np.copysign(AR_dt64, i8) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
92+
np.copysign(i8, AR_dt64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
93+
94+
np.heaviside(dt64, dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
95+
np.heaviside(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
96+
np.heaviside(dt64, i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
97+
np.heaviside(i8, AR_dt64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
98+
99+
np.logaddexp(dt64, dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
100+
np.logaddexp(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
101+
np.logaddexp(dt64, i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
102+
np.logaddexp(i8, AR_dt64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
103+
104+
np.logaddexp2(dt64, dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
105+
np.logaddexp2(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
106+
np.logaddexp2(dt64, i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
107+
np.logaddexp2(i8, AR_dt64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
108+
109+
np.nextafter(dt64, dt64) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
110+
np.nextafter(i8, i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
111+
np.nextafter(dt64, i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
112+
np.nextafter(i8, AR_dt64) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]

0 commit comments

Comments
 (0)