Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4316e00

Browse files
authoredJul 12, 2024··
Make MemoryView Generic, make cast accurate (#12247)
1 parent 3b5b642 commit 4316e00

File tree

2 files changed

+75
-5
lines changed

2 files changed

+75
-5
lines changed
 
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import array
4+
from typing_extensions import assert_type
5+
6+
# Casting to bytes.
7+
buf = b"abcdefg"
8+
view = memoryview(buf).cast("c")
9+
elm = view[0]
10+
assert_type(elm, bytes)
11+
assert_type(view[0:2], memoryview[bytes])
12+
13+
# Casting to a bool.
14+
a = array.array("B", [0, 1, 2, 3])
15+
mv = memoryview(a)
16+
bool_mv = mv.cast("?")
17+
assert_type(bool_mv[0], bool)
18+
assert_type(bool_mv[0:2], memoryview[bool])
19+
20+
21+
# Casting to a signed char.
22+
a = array.array("B", [0, 1, 2, 3])
23+
mv = memoryview(a)
24+
signed_mv = mv.cast("b")
25+
assert_type(signed_mv[0], int)
26+
assert_type(signed_mv[0:2], memoryview[int])
27+
28+
# Casting to a signed short.
29+
a = array.array("B", [0, 1, 2, 3])
30+
mv = memoryview(a)
31+
signed_mv = mv.cast("h")
32+
assert_type(signed_mv[0], int)
33+
assert_type(signed_mv[0:2], memoryview[int])
34+
35+
# Casting to a signed int.
36+
a = array.array("B", [0, 1, 2, 3])
37+
mv = memoryview(a)
38+
signed_mv = mv.cast("i")
39+
assert_type(signed_mv[0], int)
40+
assert_type(signed_mv[0:2], memoryview[int])
41+
42+
# Casting to a signed long.
43+
a = array.array("B", [0, 1, 2, 3])
44+
mv = memoryview(a)
45+
signed_mv = mv.cast("l")
46+
assert_type(signed_mv[0], int)
47+
assert_type(signed_mv[0:2], memoryview[int])
48+
49+
# Casting to a float.
50+
a = array.array("B", [0, 1, 2, 3])
51+
mv = memoryview(a)
52+
float_mv = mv.cast("f")
53+
assert_type(float_mv[0], float)
54+
assert_type(float_mv[0:2], memoryview[float])
55+
56+
# An invalid literal should raise an error.
57+
mv = memoryview(b"abc")
58+
mv.cast("abc") # type: ignore

‎stdlib/builtins.pyi

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ if sys.version_info >= (3, 9):
7575
from types import GenericAlias
7676

7777
_T = TypeVar("_T")
78+
_I = TypeVar("_I", default=int)
7879
_T_co = TypeVar("_T_co", covariant=True)
7980
_T_contra = TypeVar("_T_contra", contravariant=True)
8081
_R_co = TypeVar("_R_co", covariant=True)
@@ -823,8 +824,12 @@ class bytearray(MutableSequence[int]):
823824
def __buffer__(self, flags: int, /) -> memoryview: ...
824825
def __release_buffer__(self, buffer: memoryview, /) -> None: ...
825826

827+
_IntegerFormats: TypeAlias = Literal[
828+
"b", "B", "@b", "@B", "h", "H", "@h", "@H", "i", "I", "@i", "@I", "l", "L", "@l", "@L", "q", "Q", "@q", "@Q", "P", "@P"
829+
]
830+
826831
@final
827-
class memoryview(Sequence[int]):
832+
class memoryview(Sequence[_I]):
828833
@property
829834
def format(self) -> str: ...
830835
@property
@@ -854,13 +859,20 @@ class memoryview(Sequence[int]):
854859
def __exit__(
855860
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, /
856861
) -> None: ...
857-
def cast(self, format: str, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
858862
@overload
859-
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> int: ...
863+
def cast(self, format: Literal["c", "@c"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bytes]: ...
864+
@overload
865+
def cast(self, format: Literal["f", "@f", "d", "@d"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[float]: ...
866+
@overload
867+
def cast(self, format: Literal["?"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bool]: ...
860868
@overload
861-
def __getitem__(self, key: slice, /) -> memoryview: ...
869+
def cast(self, format: _IntegerFormats, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
870+
@overload
871+
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> _I: ...
872+
@overload
873+
def __getitem__(self, key: slice, /) -> memoryview[_I]: ...
862874
def __contains__(self, x: object, /) -> bool: ...
863-
def __iter__(self) -> Iterator[int]: ...
875+
def __iter__(self) -> Iterator[_I]: ...
864876
def __len__(self) -> int: ...
865877
def __eq__(self, value: object, /) -> bool: ...
866878
def __hash__(self) -> int: ...

0 commit comments

Comments
 (0)
Please sign in to comment.