Skip to content

Make email.message.Message generic over the header type #11732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 23, 2024
53 changes: 26 additions & 27 deletions stdlib/email/message.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,33 @@ from email import _ParamsType, _ParamType
from email.charset import Charset
from email.contentmanager import ContentManager
from email.errors import MessageDefect
from email.header import Header
from email.policy import Policy
from typing import Any, Literal, Protocol, TypeVar, overload
from typing import Any, Generic, Literal, Protocol, TypeVar, overload
from typing_extensions import Self, TypeAlias

__all__ = ["Message", "EmailMessage"]

_T = TypeVar("_T")
# Type returned by Policy.header_fetch_parse, often str or Header.
_HeaderT = TypeVar("_HeaderT", default=str)
_HeaderParamT = TypeVar("_HeaderParamT", default=str)
# Represents headers constructed by HeaderRegistry. Those are sub-classes
# of BaseHeader and another header type.
_HeaderRegistryT = TypeVar("_HeaderRegistryT", default=Any)
_HeaderRegistryParamT = TypeVar("_HeaderRegistryParamT", default=Any)

_PayloadType: TypeAlias = Message | str
_EncodedPayloadType: TypeAlias = Message | bytes
_MultipartPayloadType: TypeAlias = list[_PayloadType]
_CharsetType: TypeAlias = Charset | str | None
# Type returned by Policy.header_fetch_parse, often str or Header.
_HeaderType: TypeAlias = Any
# Type accepted by Policy.header_store_parse.
_HeaderTypeParam: TypeAlias = str | Header | Any

class _SupportsEncodeToPayload(Protocol):
def encode(self, encoding: str, /) -> _PayloadType | _MultipartPayloadType | _SupportsDecodeToPayload: ...

class _SupportsDecodeToPayload(Protocol):
def decode(self, encoding: str, errors: str, /) -> _PayloadType | _MultipartPayloadType: ...

# TODO: This class should be generic over the header policy and/or the header
# value types allowed by the policy. This depends on PEP 696 support
# (https://github.com/python/typeshed/issues/11422).
class Message:
class Message(Generic[_HeaderT, _HeaderParamT]):
policy: Policy # undocumented
preamble: str | None
epilogue: str | None
Expand Down Expand Up @@ -70,24 +70,23 @@ class Message:
# Same as `get` with `failobj=None`, but with the expectation that it won't return None in most scenarios
# This is important for protocols using __getitem__, like SupportsKeysAndGetItem
# Morally, the return type should be `AnyOf[_HeaderType, None]`,
# which we could spell as `_HeaderType | Any`,
# *but* `_HeaderType` itself is currently an alias to `Any`...
def __getitem__(self, name: str) -> _HeaderType: ...
def __setitem__(self, name: str, val: _HeaderTypeParam) -> None: ...
# so using "the Any trick" instead.
def __getitem__(self, name: str) -> _HeaderT | Any: ...
def __setitem__(self, name: str, val: _HeaderParamT) -> None: ...
def __delitem__(self, name: str) -> None: ...
def keys(self) -> list[str]: ...
def values(self) -> list[_HeaderType]: ...
def items(self) -> list[tuple[str, _HeaderType]]: ...
def values(self) -> list[_HeaderT]: ...
def items(self) -> list[tuple[str, _HeaderT]]: ...
@overload
def get(self, name: str, failobj: None = None) -> _HeaderType | None: ...
def get(self, name: str, failobj: None = None) -> _HeaderT | None: ...
@overload
def get(self, name: str, failobj: _T) -> _HeaderType | _T: ...
def get(self, name: str, failobj: _T) -> _HeaderT | _T: ...
@overload
def get_all(self, name: str, failobj: None = None) -> list[_HeaderType] | None: ...
def get_all(self, name: str, failobj: None = None) -> list[_HeaderT] | None: ...
@overload
def get_all(self, name: str, failobj: _T) -> list[_HeaderType] | _T: ...
def get_all(self, name: str, failobj: _T) -> list[_HeaderT] | _T: ...
def add_header(self, _name: str, _value: str, **_params: _ParamsType) -> None: ...
def replace_header(self, _name: str, _value: _HeaderTypeParam) -> None: ...
def replace_header(self, _name: str, _value: _HeaderParamT) -> None: ...
def get_content_type(self) -> str: ...
def get_content_maintype(self) -> str: ...
def get_content_subtype(self) -> str: ...
Expand Down Expand Up @@ -141,14 +140,14 @@ class Message:
) -> None: ...
def __init__(self, policy: Policy = ...) -> None: ...
# The following two methods are undocumented, but a source code comment states that they are public API
def set_raw(self, name: str, value: _HeaderTypeParam) -> None: ...
def raw_items(self) -> Iterator[tuple[str, _HeaderType]]: ...
def set_raw(self, name: str, value: _HeaderParamT) -> None: ...
def raw_items(self) -> Iterator[tuple[str, _HeaderT]]: ...

class MIMEPart(Message):
class MIMEPart(Message[_HeaderRegistryT, _HeaderRegistryParamT]):
def __init__(self, policy: Policy | None = None) -> None: ...
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> Message | None: ...
def iter_attachments(self) -> Iterator[Message]: ...
def iter_parts(self) -> Iterator[Message]: ...
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT] | None: ...
def iter_attachments(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
def get_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> Any: ...
def set_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> None: ...
def make_related(self, boundary: str | None = None) -> None: ...
Expand Down
40 changes: 25 additions & 15 deletions stdlib/email/parser.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,34 @@ from collections.abc import Callable
from email.feedparser import BytesFeedParser as BytesFeedParser, FeedParser as FeedParser
from email.message import Message
from email.policy import Policy
from typing import IO
from io import _WrappedBuffer
from typing import Generic, TypeVar, overload

__all__ = ["Parser", "HeaderParser", "BytesParser", "BytesHeaderParser", "FeedParser", "BytesFeedParser"]

class Parser:
def __init__(self, _class: Callable[[], Message] | None = None, *, policy: Policy = ...) -> None: ...
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> Message: ...
def parsestr(self, text: str, headersonly: bool = False) -> Message: ...
_MessageT = TypeVar("_MessageT", bound=Message, default=Message)

class HeaderParser(Parser):
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> Message: ...
def parsestr(self, text: str, headersonly: bool = True) -> Message: ...
class Parser(Generic[_MessageT]):
@overload
def __init__(self: Parser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
@overload
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> _MessageT: ...
def parsestr(self, text: str, headersonly: bool = False) -> _MessageT: ...

class BytesParser:
def __init__(self, _class: Callable[[], Message] = ..., *, policy: Policy = ...) -> None: ...
def parse(self, fp: IO[bytes], headersonly: bool = False) -> Message: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> Message: ...
class HeaderParser(Parser[_MessageT]):
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> _MessageT: ...
def parsestr(self, text: str, headersonly: bool = True) -> _MessageT: ...

class BytesHeaderParser(BytesParser):
def parse(self, fp: IO[bytes], headersonly: bool = True) -> Message: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> Message: ...
class BytesParser(Generic[_MessageT]):
parser: Parser[_MessageT]
@overload
def __init__(self: BytesParser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
@overload
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
def parse(self, fp: _WrappedBuffer, headersonly: bool = False) -> _MessageT: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> _MessageT: ...

class BytesHeaderParser(BytesParser[_MessageT]):
def parse(self, fp: _WrappedBuffer, headersonly: bool = True) -> _MessageT: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> _MessageT: ...
28 changes: 7 additions & 21 deletions stdlib/http/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import io
import ssl
import sys
import types
from _typeshed import ReadableBuffer, SupportsRead, WriteableBuffer
from _typeshed import ReadableBuffer, SupportsRead, SupportsReadline, WriteableBuffer
from collections.abc import Callable, Iterable, Iterator, Mapping
from socket import socket
from typing import Any, BinaryIO, TypeVar, overload
Expand Down Expand Up @@ -33,6 +33,7 @@ __all__ = [

_DataType: TypeAlias = SupportsRead[bytes] | Iterable[ReadableBuffer] | ReadableBuffer
_T = TypeVar("_T")
_MessageT = TypeVar("_MessageT", bound=email.message.Message)

HTTP_PORT: int
HTTPS_PORT: int
Expand Down Expand Up @@ -97,28 +98,13 @@ NETWORK_AUTHENTICATION_REQUIRED: int

responses: dict[int, str]

class HTTPMessage(email.message.Message):
class HTTPMessage(email.message.Message[str, str]):
def getallmatchingheaders(self, name: str) -> list[str]: ... # undocumented
# override below all of Message's methods that use `_HeaderType` / `_HeaderTypeParam` with `str`
# `HTTPMessage` breaks the Liskov substitution principle by only intending for `str` headers
# This is easier than making `Message` generic
def __getitem__(self, name: str) -> str | None: ...
def __setitem__(self, name: str, val: str) -> None: ... # type: ignore[override]
def values(self) -> list[str]: ...
def items(self) -> list[tuple[str, str]]: ...
@overload
def get(self, name: str, failobj: None = None) -> str | None: ...
@overload
def get(self, name: str, failobj: _T) -> str | _T: ...
@overload
def get_all(self, name: str, failobj: None = None) -> list[str] | None: ...
@overload
def get_all(self, name: str, failobj: _T) -> list[str] | _T: ...
def replace_header(self, _name: str, _value: str) -> None: ... # type: ignore[override]
def set_raw(self, name: str, value: str) -> None: ... # type: ignore[override]
def raw_items(self) -> Iterator[tuple[str, str]]: ...

def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...
@overload
def parse_headers(fp: SupportsReadline[bytes], _class: Callable[[], _MessageT]) -> _MessageT: ...
@overload
def parse_headers(fp: SupportsReadline[bytes]) -> HTTPMessage: ...

class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes
msg: HTTPMessage
Expand Down