diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index b9e4f84ec0b6..f170566ea119 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -2,7 +2,7 @@ import sys from _collections_abc import dict_items, dict_keys, dict_values from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT from types import GenericAlias -from typing import Any, ClassVar, Generic, NoReturn, SupportsIndex, TypeVar, final, overload +from typing import Any, ClassVar, NoReturn, SupportsIndex, TypeVar, final, overload from typing_extensions import Self if sys.version_info >= (3, 10): @@ -31,6 +31,7 @@ _KT = TypeVar("_KT") _VT = TypeVar("_VT") _KT_co = TypeVar("_KT_co", covariant=True) _VT_co = TypeVar("_VT_co", covariant=True) +_C = TypeVar("_C", default=int) # namedtuple is special-cased in the type checker; the initializer is ignored. def namedtuple( @@ -268,55 +269,55 @@ class deque(MutableSequence[_T]): def __eq__(self, value: object, /) -> bool: ... def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... -class Counter(dict[_T, int], Generic[_T]): +class Counter(dict[_T, _C]): @overload def __init__(self, iterable: None = None, /) -> None: ... @overload - def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ... + def __init__(self: Counter[str], iterable: None = None, /, **kwargs: _C) -> None: ... @overload - def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ... + def __init__(self, mapping: SupportsKeysAndGetItem[_T, _C], /) -> None: ... @overload def __init__(self, iterable: Iterable[_T], /) -> None: ... def copy(self) -> Self: ... def elements(self) -> Iterator[_T]: ... def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ... @classmethod - def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override] + def fromkeys(cls, iterable: Any, v: _C | None = None) -> NoReturn: ... # type: ignore[override] @overload def subtract(self, iterable: None = None, /) -> None: ... @overload - def subtract(self, mapping: Mapping[_T, int], /) -> None: ... + def subtract(self, mapping: Mapping[_T, _C], /) -> None: ... @overload def subtract(self, iterable: Iterable[_T], /) -> None: ... # Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload # (source code does an `isinstance(other, Mapping)` check) # # The second overload is also deliberately different to dict.update() - # (if it were `Iterable[_T] | Iterable[tuple[_T, int]]`, + # (if it were `Iterable[_T] | Iterable[tuple[_T, _C]]`, # the tuples would be added as keys, breaking type safety) @overload # type: ignore[override] - def update(self, m: Mapping[_T, int], /, **kwargs: int) -> None: ... + def update(self, m: Mapping[_T, _C], /, **kwargs: _C) -> None: ... @overload - def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ... + def update(self, iterable: Iterable[_T], /, **kwargs: _C) -> None: ... @overload - def update(self, iterable: None = None, /, **kwargs: int) -> None: ... - def __missing__(self, key: _T) -> int: ... + def update(self, iterable: None = None, /, **kwargs: _C) -> None: ... + def __missing__(self, key: _T) -> _C: ... def __delitem__(self, elem: object) -> None: ... if sys.version_info >= (3, 10): def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... - def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ... - def __sub__(self, other: Counter[_T]) -> Counter[_T]: ... - def __and__(self, other: Counter[_T]) -> Counter[_T]: ... - def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override] - def __pos__(self) -> Counter[_T]: ... - def __neg__(self) -> Counter[_T]: ... + def __add__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ... + def __sub__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ... + def __and__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ... + def __or__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ... # type: ignore[override] + def __pos__(self) -> Counter[_T, _C]: ... + def __neg__(self) -> Counter[_T, _C]: ... # several type: ignores because __iadd__ is supposedly incompatible with __add__, etc. - def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc] - def __isub__(self, other: SupportsItems[_T, int]) -> Self: ... - def __iand__(self, other: SupportsItems[_T, int]) -> Self: ... - def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc] + def __iadd__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[misc] + def __isub__(self, other: SupportsItems[_T, _C]) -> Self: ... + def __iand__(self, other: SupportsItems[_T, _C]) -> Self: ... + def __ior__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[override,misc] if sys.version_info >= (3, 10): def total(self) -> int: ... def __le__(self, other: Counter[Any]) -> bool: ... diff --git a/test_cases/stdlib/collections/check_counter.py b/test_cases/stdlib/collections/check_counter.py new file mode 100644 index 000000000000..581ba67b98e4 --- /dev/null +++ b/test_cases/stdlib/collections/check_counter.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from collections import Counter +from typing import Any, cast +from typing_extensions import assert_type + + +class Foo: ... + + +# Test the constructor +# mypy derives Never for the first type argument while, pyright derives Unknown +assert_type(Counter(), "Counter[Any, int]") +assert_type(Counter(foo=42.2), "Counter[str, float]") +assert_type(Counter({42: "bar"}), "Counter[int, str]") +assert_type(Counter([1, 2, 3]), "Counter[int, int]") + +int_c: Counter[str] = Counter() +assert_type(int_c, "Counter[str, int]") +assert_type(int_c["a"], int) +int_c["a"] = 1 +int_c["a"] += 3 +int_c["a"] += 3.5 # type: ignore + +float_c = Counter(foo=42.2) +assert_type(float_c, "Counter[str, float]") +assert_type(float_c["a"], float) +float_c["a"] = 1.0 +float_c["a"] += 3.0 +float_c["a"] += 42 +float_c["a"] += "42" # type: ignore + +custom_c = cast("Counter[str, Foo]", Counter()) +assert_type(custom_c, "Counter[str, Foo]") +assert_type(custom_c["a"], Foo) +custom_c["a"] = Foo() +custom_c["a"] += Foo() # type: ignore +custom_c["a"] += 42 # type: ignore