From cc3f5c2e54d3bce203ac510501120aa94c25bd1b Mon Sep 17 00:00:00 2001 From: sobolevn Date: Tue, 23 Aug 2022 17:10:09 +0300 Subject: [PATCH 1/2] Mypy now treats classes with `__getitem__` as iterable --- mypy/checker.py | 45 +++++--- mypy/checkexpr.py | 18 ++- mypy/checkmember.py | 6 +- mypy/constraints.py | 3 +- mypy/subtypes.py | 58 +++++++++- test-data/unit/check-protocols.test | 163 ++++++++++++++++++++++++++++ test-data/unit/deps-statements.test | 25 ++++- test-data/unit/fine-grained.test | 45 ++++++++ test-data/unit/pythoneval.test | 12 ++ 9 files changed, 337 insertions(+), 38 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 076f9e3763d9..891484e8ffa0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3429,9 +3429,9 @@ def type_is_iterable(self, type: Type) -> bool: type = get_proper_type(type) if isinstance(type, CallableType) and type.is_type_obj(): type = type.fallback - return is_subtype( - type, self.named_generic_type("typing.Iterable", [AnyType(TypeOfAny.special_form)]) - ) + with self.msg.filter_errors() as iter_errors: + self.analyze_iterable_item_type(TempNode(type)) + return not iter_errors.has_new_errors() def check_multi_assignment_from_iterable( self, @@ -4247,15 +4247,36 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) - iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] + # We first try to find `__iter__` magic method. + # If it is present, we go on with it. + # But, python also support iterables with just `__getitem__(index) -> Any` defined. + # So, we check it in case `__iter__` is missing. + with self.msg.filter_errors(save_filtered_errors=True) as iter_errors: + # We save original error to show it later if `__getitem__` is also missing. + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] + if iter_errors.has_new_errors(): + # `__iter__` is missing, try `__getattr__`: + arg = self.temp_node(AnyType(TypeOfAny.implementation_artifact), expr) + with self.msg.filter_errors() as getitem_errors: + getitem_type = echk.check_method_call_by_name( + "__getitem__", iterable, [arg], [nodes.ARG_POS], expr + )[0] + if getitem_errors.has_new_errors(): # Both are missing. + self.msg.add_errors(iter_errors.filtered_errors()) + return AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error) + else: + # We found just `__getitem__`, it does not follow the same + # semantics as `__iter__`, so: just return what we found. + return self.named_generic_type("typing.Iterator", [getitem_type]), getitem_type + + # We found `__iter__`, let's analyze its return type: if isinstance(iterable, TupleType): joined: Type = UninhabitedType() for item in iterable.items: joined = join_types(joined, item) return iterator, joined - else: - # Non-tuple iterable. + else: # Non-tuple iterable. return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] def analyze_container_item_type(self, typ: Type) -> Type | None: @@ -5986,15 +6007,9 @@ def iterable_item_type(self, instance: Instance) -> Type: # This relies on 'map_instance_to_supertype' returning 'Iterable[Any]' # in case there is no explicit base class. return item_type - # Try also structural typing. - iter_type = get_proper_type(find_member("__iter__", instance, instance, is_operator=True)) - if iter_type and isinstance(iter_type, CallableType): - ret_type = get_proper_type(iter_type.ret_type) - if isinstance(ret_type, Instance): - iterator = map_instance_to_supertype( - ret_type, self.lookup_typeinfo("typing.Iterator") - ) - item_type = iterator.args[0] + + # Try also structural typing: including `__iter__` and `__getitem__`. + _, item_type = self.analyze_iterable_item_type(TempNode(instance)) return item_type def function_type(self, func: FuncBase) -> FunctionLike: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 825230c227d9..ac7662ddff19 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2998,13 +2998,14 @@ def check_method_call_by_name( method, base_type, context, - False, - False, - True, - self.msg, + is_lvalue=False, + is_super=False, + is_operator=True, + msg=self.msg, original_type=original_type, chk=self.chk, in_literal_context=self.is_literal_context(), + suggest_awaitable=False, ) return self.check_method_call(method, base_type, method_type, args, arg_kinds, context) @@ -4806,13 +4807,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals if is_async_def(subexpr_type) and not has_coroutine_decorator(return_type): self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) - any_type = AnyType(TypeOfAny.special_form) - generic_generator_type = self.chk.named_generic_type( - "typing.Generator", [any_type, any_type, any_type] - ) - iter_type, _ = self.check_method_call_by_name( - "__iter__", subexpr_type, [], [], context=generic_generator_type - ) + iter_type, _ = self.chk.analyze_iterable_item_type(TempNode(subexpr_type)) + iter_type = get_proper_type(iter_type) else: if not (is_async_def(subexpr_type) and has_coroutine_decorator(return_type)): self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 3be961ee9fdc..428a19487903 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -90,6 +90,7 @@ def __init__( chk: mypy.checker.TypeChecker, self_type: Type | None, module_symbol_table: SymbolTable | None = None, + suggest_awaitable: bool = True, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -100,6 +101,7 @@ def __init__( self.msg = msg self.chk = chk self.module_symbol_table = module_symbol_table + self.suggest_awaitable = suggest_awaitable def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -149,6 +151,7 @@ def analyze_member_access( in_literal_context: bool = False, self_type: Type | None = None, module_symbol_table: SymbolTable | None = None, + suggest_awaitable: bool = True, ) -> Type: """Return the type of attribute 'name' of 'typ'. @@ -183,6 +186,7 @@ def analyze_member_access( chk=chk, self_type=self_type, module_symbol_table=module_symbol_table, + suggest_awaitable=suggest_awaitable, ) result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) @@ -260,7 +264,7 @@ def report_missing_attribute( override_info: TypeInfo | None = None, ) -> Type: res_type = mx.msg.has_no_attr(original_type, typ, name, mx.context, mx.module_symbol_table) - if may_be_awaitable_attribute(name, typ, mx, override_info): + if mx.suggest_awaitable and may_be_awaitable_attribute(name, typ, mx, override_info): mx.msg.possible_missing_await(mx.context) return res_type diff --git a/mypy/constraints.py b/mypy/constraints.py index 9e28ce503b6c..19b2bc32e263 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -750,8 +750,7 @@ def infer_constraints_from_protocol_members( """ res = [] for member in protocol.type.protocol_members: - inst = mypy.subtypes.find_member(member, instance, subtype) - temp = mypy.subtypes.find_member(member, template, subtype) + inst, temp = mypy.subtypes.find_members(member, instance, template, subtype) if inst is None or temp is None: return [] # See #11020 # The above is safe since at this point we know that 'instance' is a subtype diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a7ff37b8a62f..3f17c52e6dd9 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Any, Callable, Iterator, List, TypeVar, cast +from typing import Any, Callable, Iterator, List, Tuple, TypeVar, cast from typing_extensions import Final, TypeAlias as _TypeAlias import mypy.applytype @@ -14,6 +14,7 @@ # Circular import; done in the function instead. # import mypy.solve from mypy.nodes import ( + ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, @@ -956,9 +957,10 @@ def f(self) -> A: ... ignore_names = member != "__call__" # __call__ can be passed kwargs # The third argument below indicates to what self type is bound. # We always bind self to the subtype. (Similarly to nominal types). - supertype = get_proper_type(find_member(member, right, left)) + supertype, subtype = find_members(member, right, left, left) + supertype = get_proper_type(supertype) assert supertype is not None - subtype = get_proper_type(find_member(member, left, left)) + subtype = get_proper_type(subtype) # Useful for debugging: # print(member, 'of', left, 'has type', subtype) # print(member, 'of', right, 'has type', supertype) @@ -1012,6 +1014,56 @@ def f(self) -> A: ... return True +def find_members( + name: str, supertype: Instance, subtype: Instance, context: Type +) -> Tuple[Type | None, Type | None]: + """Find types of member by name for two instances. + + We do it with respect to some special cases, like `Iterable` and `__geitem__`. + """ + if name == "__iter__": + # So, this is a special case: old-style iterbale protocol + # must be supported even without explicit `__iter__` method. + # Because all types with `__geitem__` defined have default `__iter__` + # implementation. See #2220 + # First, we need to find which is one actually `Iterable`: + if is_named_instance(supertype, "typing.Iterable"): + left, right = _iterable_special_member(supertype, subtype, context) + if left is not None and right is not None: + return left, right + elif is_named_instance(subtype, "typing.Iterable"): + left, right = _iterable_special_member(subtype, supertype, context) + if left is not None and right is not None: + return right, left + + # This is not a special case. + # Falling back to regular `find_member` call: + return (find_member(name, supertype, context), find_member(name, subtype, context)) + + +def _iterable_special_member( + iterable: Instance, candidate: Instance, context: Type +) -> Tuple[Type | None, Type | None]: + name = "__iter__" + iterable_method = get_proper_type(find_member(name, iterable, context)) + candidate_method = get_proper_type(find_member("__getitem__", candidate, context)) + if isinstance(iterable_method, CallableType) and isinstance( + (ret := get_proper_type(iterable_method.ret_type)), Instance + ): + # We need to transform + # `__iter__() -> Iterable[ret]` into + # `__getitem__(Any) -> ret` + iterable_method = iterable_method.copy_modified( + arg_names=[None], + arg_types=[AnyType(TypeOfAny.implementation_artifact)], + arg_kinds=[ARG_POS], + ret_type=ret.args[0], + name="__getitem__", + ) + return (iterable_method, candidate_method) + return None, None + + def find_member( name: str, itype: Instance, subtype: Type, is_operator: bool = False ) -> Type | None: diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 9be657257fe1..fa1033d93b77 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -2032,6 +2032,158 @@ reveal_type(list(b for b in B())) # N: Revealed type is "builtins.list[__main__ reveal_type(list(B())) # N: Revealed type is "builtins.list[__main__.B]" [builtins fixtures/list.pyi] +[case testOldStyleIterableOnClass] +from typing import Tuple, TypeVar, Generic + +class CorrectGetItem: + def __getitem__(self, arg: int) -> str: pass +class CorrectChild(CorrectGetItem): pass + +T = TypeVar('T') +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass + +class ReturnsTuple: + def __getitem__(self, arg: int) -> Tuple[int, str]: pass + +# Despite the fact that `arg` must have `int` type due to the spec, +# we allow other signatures right now. +class WrongGetItemSig: + def __getitem__(self, arg: str) -> str: pass + +reveal_type(list(a for a in CorrectGetItem())) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(a for a in CorrectChild())) # N: Revealed type is "builtins.list[builtins.str]" + +x: WithGeneric[int] +reveal_type(list(a for a in x)) # N: Revealed type is "builtins.list[builtins.int]" + +reveal_type(list(a for a in ReturnsTuple())) # N: Revealed type is "builtins.list[Tuple[builtins.int, builtins.str]]" + +reveal_type(list(a for a in WrongGetItemSig())) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + +[case testOldIterableTypeAliasAndFor] +from typing import Iterator + +KeyId = str + +class A: + def __getitem__(self, arg: int) -> KeyId: pass + +for a in A(): + reveal_type(a) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testIterableAndOldIterableProtocolsOnClass] +from typing import Iterator + +class A: + def __iter__(self) -> Iterator[str]: pass + def __getitem__(self, arg: int) -> int: pass +class B(A): pass + +reveal_type(list(a for a in A())) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(b for b in B())) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + +[case testOldStyleIterableProtocolYieldFrom] +from typing import Iterator +class A: + def __getitem__(self, arg: int) -> int: pass +def f() -> Iterator[str]: + yield from A() # E: Incompatible types in "yield from" (actual type "int", expected type "str") +[builtins fixtures/list.pyi] + +[case testOldStyleIterableProtocolUnpack] +class A: + def __getitem__(self, arg: int) -> str: pass + +a, b = A() +reveal_type(a) # N: Revealed type is "builtins.str" +reveal_type(b) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testGetItemProtocolJustInCase] +from typing import Protocol, Generic, TypeVar + +TCO = TypeVar('TCO', covariant=True) +T = TypeVar('T') + +class GetItemProto(Protocol[TCO]): + def __getitem__(self, arg: int) -> TCO: pass + +def getitem(p: GetItemProto[T]) -> T: + return p[0] + +class Regular: + def __getitem__(self, arg: int) -> str: pass +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass +class DifferentSig: + def __getitem__(self, arg: str) -> str: pass + +reveal_type(getitem(Regular())) # N: Revealed type is "builtins.str" + +x: WithGeneric[str] +reveal_type(getitem(x)) # N: Revealed type is "builtins.str" + +getitem(DifferentSig()) # E: Argument 1 to "getitem" has incompatible type "DifferentSig"; expected "GetItemProto[]" +[builtins fixtures/list.pyi] + +[case testIterableAndGetItemAreCompatible] +from typing import Iterable, TypeVar, Generic + +T = TypeVar('T') + +def some_generic(i: Iterable[T]) -> T: pass +def some_regular(i: Iterable[int]): pass + +class Regular: + def __getitem__(self, arg: int) -> int: pass +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass + +reveal_type(some_generic(Regular())) # N: Revealed type is "builtins.int" +x: WithGeneric[int] +reveal_type(some_generic(x)) # N: Revealed type is "builtins.int" + +some_regular(Regular()) +some_regular(x) +y: WithGeneric[str] +some_regular(y) # E: Argument 1 to "some_regular" has incompatible type "WithGeneric[str]"; expected "Iterable[int]" +[builtins fixtures/list.pyi] + +[case testIterableAndGetItemProtocolsAreNotCompatible] +from typing import Iterable, Protocol + +def expects_iterable(i: Iterable[int]): pass + +class GetItem(Protocol): + def __getitem__(self, arg: int) -> int: pass + +def expects_getitem(i: GetItem): pass + +x: GetItem +expects_iterable(x) # E: Argument 1 to "expects_iterable" has incompatible type "GetItem"; expected "Iterable[int]" + +y: Iterable[int] +expects_getitem(y) # E: Argument 1 to "expects_getitem" has incompatible type "Iterable[int]"; expected "GetItem" +[builtins fixtures/list.pyi] + +[case testCustomIterableAndGetItemProtocolsAreNotCompatible] +from typing import Protocol, Iterator + +class MyIterable(Protocol): + def __iter__(self) -> Iterator[int]: pass + +class Regular: + def __getitem__(self, arg: int) -> int: pass + +def some(i: MyIterable): pass + +some(Regular()) # E: Argument 1 to "some" has incompatible type "Regular"; expected "MyIterable" +[builtins fixtures/list.pyi] + [case testIterableProtocolOnMetaclass] from typing import TypeVar, Iterator, Type T = TypeVar('T') @@ -2049,6 +2201,17 @@ reveal_type(list(c for c in C)) # N: Revealed type is "builtins.list[__main__.C reveal_type(list(C)) # N: Revealed type is "builtins.list[__main__.C]" [builtins fixtures/list.pyi] +[case testOldStyleIterableProtocolOnMetaclass] +class EMeta(type): + def __getitem__(self, arg: int) -> str: pass + +class E(metaclass=EMeta): pass +class C(E): pass + +reveal_type(list(e for e in E)) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(c for c in C)) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + [case testClassesGetattrWithProtocols] from typing import Protocol diff --git a/test-data/unit/deps-statements.test b/test-data/unit/deps-statements.test index a67f9c762009..70d349e87bb0 100644 --- a/test-data/unit/deps-statements.test +++ b/test-data/unit/deps-statements.test @@ -256,8 +256,8 @@ def f() -> None: a: A x, y = a [out] - -> , m.f - -> m.A, m.f, typing.Iterable + -> m.f + -> m.A, m.f [case testMultipleLvalues] class A: @@ -327,8 +327,8 @@ def g() -> None: -> m.A.f, m.g -> m.A, m.g -> m.g - -> , m.g - -> , m.B, m.f, typing.Iterable + -> m.g + -> , m.B, m.f -> m.g [case testNestedSetItem] @@ -449,9 +449,22 @@ def f() -> Iterator[int]: yield from A() [out] -> m.f - -> , m.f + -> m.f -> m.f - -> m.A, m.f, typing.Iterable + -> m.A, m.f + +[case testForAndOldStyleIterable] +class A: + def __getitem__(self, arg: int) -> int: pass + +for a in A(): pass +[out] + -> m + -> m + -> m + -> m + -> m, m.A + -> m [case testFunctionDecorator] from typing import Callable diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 3a054e8fcfe5..bc7d5950a921 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -6956,6 +6956,51 @@ class B: == main:3: error: "A" has no attribute "__iter__" (not iterable) +[case testWeAreCarefulWithOldIterProtocol] +import a +x: a.A +for i in x: + pass +[file a.py] +class A: + def __getitem__(self, arg: int) -> str: + pass +[file a.py.2] +class A: + pass +[out] +== +main:3: error: "A" has no attribute "__iter__" (not iterable) + +[case testWeAreCarefulWithIterProtocolSwitches] +import a +x: a.A +for i in x: + reveal_type(i) + +y: a.B +for j in y: + reveal_type(j) +[file a.py] +from typing import Iterator +class A: + def __iter__(self) -> Iterator[int]: pass +class B: + def __getitem__(self, arg: int) -> str: pass +[file a.py.2] +from typing import Iterator +class A: + def __getitem__(self, arg: int) -> str: pass +class B: + def __iter__(self) -> Iterator[int]: pass +[out] +main:4: note: Revealed type is "builtins.int" +main:8: note: Revealed type is "builtins.str" +== +main:4: note: Revealed type is "builtins.str" +main:8: note: Revealed type is "builtins.int" + + [case testOverloadsSimpleFrom] import a [file a.py] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index d7d20a923984..f2353bc9b0f9 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -27,6 +27,18 @@ print(list(reversed(A()))) \['c', 'b', 'a'] \['f', 'o', 'o'] +[case testOldIterableAndIterableProtocolAreIncompatible] +from typing import Iterable + +def some(i: Iterable[int]): pass + +class A: + def __getitem__(self, arg: int) -> int: pass + +# We treat `Iterable` and types with `__getitem__` compatible: see #13485 +some(A()) +[out] + [case testIntAndFloatConversion] from typing import SupportsInt, SupportsFloat class A(SupportsInt): From c6b155a87258a8f4c9e7599b6872faf421da07d7 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Sun, 28 Aug 2022 18:47:36 +0300 Subject: [PATCH 2/2] `:=` is not supported on 3.7 --- mypy/subtypes.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 3e74363541c1..96b74e97dc01 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1059,20 +1059,20 @@ def _find_iter( ) -> Tuple[Type | None, Type | None]: iterable_method = get_proper_type(find_member("__iter__", iterable, context)) candidate_method = get_proper_type(find_member("__getitem__", candidate, context)) - if isinstance(iterable_method, CallableType) and isinstance( - (ret := get_proper_type(iterable_method.ret_type)), Instance - ): - # We need to transform - # `__iter__() -> Iterable[ret]` into - # `__getitem__(Any) -> ret` - iterable_method = iterable_method.copy_modified( - arg_names=[None], - arg_types=[AnyType(TypeOfAny.implementation_artifact)], - arg_kinds=[ARG_POS], - ret_type=ret.args[0], - name="__getitem__", - ) - return (iterable_method, candidate_method) + if isinstance(iterable_method, CallableType): + ret = get_proper_type(iterable_method.ret_type) + if isinstance(ret, Instance): + # We need to transform + # `__iter__() -> Iterable[ret]` into + # `__getitem__(Any) -> ret` + iterable_method = iterable_method.copy_modified( + arg_names=[None], + arg_types=[AnyType(TypeOfAny.implementation_artifact)], + arg_kinds=[ARG_POS], + ret_type=ret.args[0], + name="__getitem__", + ) + return (iterable_method, candidate_method) return None, None # First, we need to find which is one actually `Iterable`: