Skip to content

typing: accept buffers in typing.IO.write #9084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions stdlib/codecs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class StreamReaderWriter(TextIO):
def __next__(self) -> str: ...
def __iter__(self) -> Self: ...
def write(self, data: str) -> None: ... # type: ignore[override]
def writelines(self, list: Iterable[str]) -> None: ...
def writelines(self, list: Iterable[str]) -> None: ... # type: ignore[override]
def reset(self) -> None: ...
def seek(self, offset: int, whence: int = 0) -> None: ... # type: ignore[override]
def __enter__(self) -> Self: ...
Expand Down Expand Up @@ -272,8 +272,9 @@ class StreamRecoder(BinaryIO):
def readlines(self, sizehint: int | None = None) -> list[bytes]: ...
def __next__(self) -> bytes: ...
def __iter__(self) -> Self: ...
# Base class accepts more types than just bytes
def write(self, data: bytes) -> None: ... # type: ignore[override]
def writelines(self, list: Iterable[bytes]) -> None: ...
def writelines(self, list: Iterable[bytes]) -> None: ... # type: ignore[override]
def reset(self) -> None: ...
def __getattr__(self, name: str) -> Any: ...
def __enter__(self) -> Self: ...
Expand Down
3 changes: 2 additions & 1 deletion stdlib/http/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ class HTTPMessage(email.message.Message):

def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...

class HTTPResponse(io.BufferedIOBase, BinaryIO):
# superclasses have incompatible write and writelines methods
class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc]
msg: HTTPMessage
headers: HTTPMessage
version: int
Expand Down
14 changes: 8 additions & 6 deletions stdlib/io.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BufferedIOBase(IOBase):
def read(self, __size: int | None = ...) -> bytes: ...
def read1(self, __size: int = ...) -> bytes: ...

class FileIO(RawIOBase, BinaryIO):
class FileIO(RawIOBase, BinaryIO): # type: ignore[misc]
mode: str
name: FileDescriptorOrPath # type: ignore[assignment]
def __init__(
Expand All @@ -102,7 +102,7 @@ class FileIO(RawIOBase, BinaryIO):
def read(self, __size: int = -1) -> bytes: ...
def __enter__(self) -> Self: ...

class BytesIO(BufferedIOBase, BinaryIO):
class BytesIO(BufferedIOBase, BinaryIO): # type: ignore[misc]
def __init__(self, initial_bytes: ReadableBuffer = ...) -> None: ...
# BytesIO does not contain a "name" field. This workaround is necessary
# to allow BytesIO sub-classes to add this field, as it is defined
Expand All @@ -113,17 +113,18 @@ class BytesIO(BufferedIOBase, BinaryIO):
def getbuffer(self) -> memoryview: ...
def read1(self, __size: int | None = -1) -> bytes: ...

class BufferedReader(BufferedIOBase, BinaryIO):
# superclasses have incompatible write and writelines methods
class BufferedReader(BufferedIOBase, BinaryIO): # type: ignore[misc]
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def peek(self, __size: int = 0) -> bytes: ...

class BufferedWriter(BufferedIOBase, BinaryIO):
class BufferedWriter(BufferedIOBase, BinaryIO): # type: ignore[misc]
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def write(self, __buffer: ReadableBuffer) -> int: ...

class BufferedRandom(BufferedReader, BufferedWriter):
class BufferedRandom(BufferedReader, BufferedWriter): # type: ignore[misc]
def __enter__(self) -> Self: ...
def seek(self, __target: int, __whence: int = 0) -> int: ... # stubtest needs this

Expand All @@ -144,7 +145,7 @@ class TextIOBase(IOBase):
def readlines(self, __hint: int = -1) -> list[str]: ... # type: ignore[override]
def read(self, __size: int | None = ...) -> str: ...

class TextIOWrapper(TextIOBase, TextIO):
class TextIOWrapper(TextIOBase, TextIO): # type: ignore[misc]
def __init__(
self,
buffer: IO[bytes],
Expand Down Expand Up @@ -176,6 +177,7 @@ class TextIOWrapper(TextIOBase, TextIO):
def __iter__(self) -> Iterator[str]: ... # type: ignore[override]
def __next__(self) -> str: ... # type: ignore[override]
def writelines(self, __lines: Iterable[str]) -> None: ... # type: ignore[override]
def write(self, __s: str) -> int: ... # type: ignore[override]
def readline(self, __size: int = -1) -> str: ... # type: ignore[override]
def readlines(self, __hint: int = -1) -> list[str]: ... # type: ignore[override]
def seek(self, __cookie: int, __whence: int = 0) -> int: ... # stubtest needs this
Expand Down
2 changes: 1 addition & 1 deletion stdlib/lzma.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LZMACompressor:

class LZMAError(Exception): ...

class LZMAFile(io.BufferedIOBase, IO[bytes]):
class LZMAFile(io.BufferedIOBase, IO[bytes]): # type: ignore[misc] # incomptatible definitions of writelines in the 2 bases
def __init__(
self,
filename: _PathOrFile | None = None,
Expand Down
16 changes: 14 additions & 2 deletions stdlib/tempfile.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import sys
from _typeshed import BytesPath, GenericPath, StrPath, WriteableBuffer
from _typeshed import BytesPath, GenericPath, ReadableBuffer, StrPath, WriteableBuffer
from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import IO, Any, AnyStr, Generic, overload
Expand Down Expand Up @@ -215,8 +215,14 @@ class _TemporaryFileWrapper(Generic[AnyStr], IO[AnyStr]):
def tell(self) -> int: ...
def truncate(self, size: int | None = ...) -> int: ...
def writable(self) -> bool: ...
@overload
def write(self, s: AnyStr) -> int: ...
@overload
def write(self: _TemporaryFileWrapper[bytes], s: ReadableBuffer) -> int: ...
@overload
def writelines(self, lines: Iterable[AnyStr]) -> None: ...
@overload
def writelines(self: _TemporaryFileWrapper[bytes], lines: Iterable[ReadableBuffer]) -> None: ...

if sys.version_info >= (3, 11):
_SpooledTemporaryFileBase = io.IOBase
Expand Down Expand Up @@ -392,8 +398,14 @@ class SpooledTemporaryFile(IO[AnyStr], _SpooledTemporaryFileBase):
def seek(self, offset: int, whence: int = ...) -> int: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None) -> None: ... # type: ignore[override]
@overload
def write(self, s: AnyStr) -> int: ...
def writelines(self, iterable: Iterable[AnyStr]) -> None: ... # type: ignore[override]
@overload
def write(self: SpooledTemporaryFile[bytes], s: ReadableBuffer) -> int: ...
@overload
def writelines(self, iterable: Iterable[AnyStr]) -> None: ...
@overload
def writelines(self: SpooledTemporaryFile[bytes], iterable: Iterable[ReadableBuffer]) -> None: ...
def __iter__(self) -> Iterator[AnyStr]: ... # type: ignore[override]
# These exist at runtime only on 3.11+.
def readable(self) -> bool: ...
Expand Down
14 changes: 11 additions & 3 deletions stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986
import sys
import typing_extensions
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import IdentityFunction, Incomplete, SupportsKeysAndGetItem
from _typeshed import IdentityFunction, Incomplete, ReadableBuffer, SupportsKeysAndGetItem
from abc import ABCMeta, abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from re import Match as Match, Pattern as Pattern
Expand Down Expand Up @@ -687,9 +687,17 @@ class IO(Iterator[AnyStr], Generic[AnyStr]):
@abstractmethod
def writable(self) -> bool: ...
@abstractmethod
def write(self, __s: AnyStr) -> int: ...
@overload
def write(self: IO[str], __s: str) -> int: ...
@abstractmethod
@overload
def write(self: IO[bytes], __s: ReadableBuffer) -> int: ...
@abstractmethod
def writelines(self, __lines: Iterable[AnyStr]) -> None: ...
@overload
def writelines(self: IO[str], __lines: Iterable[str]) -> None: ...
@abstractmethod
@overload
def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ...
@abstractmethod
def __next__(self) -> AnyStr: ...
@abstractmethod
Expand Down
21 changes: 21 additions & 0 deletions test_cases/stdlib/typing/check_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

import mmap
from typing import IO, AnyStr


def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None:
io_bytes.write(b"")
io_bytes.write(buf)
io_bytes.write("") # type: ignore
io_bytes.write(any_str) # type: ignore

io_str.write(b"") # type: ignore
io_str.write(buf) # type: ignore
io_str.write("")
io_str.write(any_str) # type: ignore

io_anystr.write(b"") # type: ignore
io_anystr.write(buf) # type: ignore
io_anystr.write("") # type: ignore
io_anystr.write(any_str)