Skip to content

Make MemoryView Generic, make cast accurate #12247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 12, 2024
58 changes: 58 additions & 0 deletions stdlib/@tests/test_cases/builtins/check_memoryview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import array
from typing_extensions import assert_type

# Casting to bytes.
buf = b"abcdefg"
view = memoryview(buf).cast("c")
elm = view[0]
assert_type(elm, bytes)
assert_type(view[0:2], memoryview[bytes])

# Casting to a bool.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
bool_mv = mv.cast("?")
assert_type(bool_mv[0], bool)
assert_type(bool_mv[0:2], memoryview[bool])


# Casting to a signed char.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("b")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])

# Casting to a signed short.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("h")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])

# Casting to a signed int.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("i")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])

# Casting to a signed long.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("l")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])

# Casting to a float.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
float_mv = mv.cast("f")
assert_type(float_mv[0], float)
assert_type(float_mv[0:2], memoryview[float])

# An invalid literal should raise an error.
mv = memoryview(b"abc")
mv.cast("abc") # type: ignore
22 changes: 17 additions & 5 deletions stdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ if sys.version_info >= (3, 9):
from types import GenericAlias

_T = TypeVar("_T")
_I = TypeVar("_I", default=int)
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
_R_co = TypeVar("_R_co", covariant=True)
Expand Down Expand Up @@ -823,8 +824,12 @@ class bytearray(MutableSequence[int]):
def __buffer__(self, flags: int, /) -> memoryview: ...
def __release_buffer__(self, buffer: memoryview, /) -> None: ...

_IntegerFormats: TypeAlias = Literal[
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if we'd prefer these to be inline for the sake of IDEs, seemed harder to maintain so just decided to place them outside.

"b", "B", "@b", "@B", "h", "H", "@h", "@H", "i", "I", "@i", "@I", "l", "L", "@l", "@L", "q", "Q", "@q", "@Q", "P", "@P"
]

@final
class memoryview(Sequence[int]):
class memoryview(Sequence[_I]):
@property
def format(self) -> str: ...
@property
Expand Down Expand Up @@ -854,13 +859,20 @@ class memoryview(Sequence[int]):
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, /
) -> None: ...
def cast(self, format: str, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
@overload
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> int: ...
def cast(self, format: Literal["c", "@c"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bytes]: ...
@overload
def cast(self, format: Literal["f", "@f", "d", "@d"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[float]: ...
@overload
def cast(self, format: Literal["?"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bool]: ...
@overload
def __getitem__(self, key: slice, /) -> memoryview: ...
def cast(self, format: _IntegerFormats, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
@overload
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> _I: ...
@overload
def __getitem__(self, key: slice, /) -> memoryview[_I]: ...
def __contains__(self, x: object, /) -> bool: ...
def __iter__(self) -> Iterator[int]: ...
def __iter__(self) -> Iterator[_I]: ...
def __len__(self) -> int: ...
def __eq__(self, value: object, /) -> bool: ...
def __hash__(self) -> int: ...
Expand Down
Loading