diff --git a/mypy/checker.py b/mypy/checker.py index 106c8e9a0351..e834f7a06600 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3460,9 +3460,9 @@ def type_is_iterable(self, type: Type) -> bool: type = get_proper_type(type) if isinstance(type, CallableType) and type.is_type_obj(): type = type.fallback - return is_subtype( - type, self.named_generic_type("typing.Iterable", [AnyType(TypeOfAny.special_form)]) - ) + with self.msg.filter_errors() as iter_errors: + self.analyze_iterable_item_type(TempNode(type)) + return not iter_errors.has_new_errors() def check_multi_assignment_from_iterable( self, @@ -4278,15 +4278,36 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) - iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] + # We first try to find `__iter__` magic method. + # If it is present, we go on with it. + # But, python also support iterables with just `__getitem__(index) -> Any` defined. + # So, we check it in case `__iter__` is missing. + with self.msg.filter_errors(save_filtered_errors=True) as iter_errors: + # We save original error to show it later if `__getitem__` is also missing. + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] + if iter_errors.has_new_errors(): + # `__iter__` is missing, try `__getattr__`: + arg = self.temp_node(AnyType(TypeOfAny.implementation_artifact), expr) + with self.msg.filter_errors() as getitem_errors: + getitem_type = echk.check_method_call_by_name( + "__getitem__", iterable, [arg], [nodes.ARG_POS], expr + )[0] + if getitem_errors.has_new_errors(): # Both are missing. + self.msg.add_errors(iter_errors.filtered_errors()) + return AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error) + else: + # We found just `__getitem__`, it does not follow the same + # semantics as `__iter__`, so: just return what we found. + return self.named_generic_type("typing.Iterator", [getitem_type]), getitem_type + + # We found `__iter__`, let's analyze its return type: if isinstance(iterable, TupleType): joined: Type = UninhabitedType() for item in iterable.items: joined = join_types(joined, item) return iterator, joined - else: - # Non-tuple iterable. + else: # Non-tuple iterable. return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] def analyze_container_item_type(self, typ: Type) -> Type | None: @@ -6014,15 +6035,9 @@ def iterable_item_type(self, instance: Instance) -> Type: # This relies on 'map_instance_to_supertype' returning 'Iterable[Any]' # in case there is no explicit base class. return item_type - # Try also structural typing. - iter_type = get_proper_type(find_member("__iter__", instance, instance, is_operator=True)) - if iter_type and isinstance(iter_type, CallableType): - ret_type = get_proper_type(iter_type.ret_type) - if isinstance(ret_type, Instance): - iterator = map_instance_to_supertype( - ret_type, self.lookup_typeinfo("typing.Iterator") - ) - item_type = iterator.args[0] + + # Try also structural typing: including `__iter__` and `__getitem__`. + _, item_type = self.analyze_iterable_item_type(TempNode(instance)) return item_type def function_type(self, func: FuncBase) -> FunctionLike: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ad0436ada214..a20942707f14 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3013,13 +3013,14 @@ def check_method_call_by_name( method, base_type, context, - False, - False, - True, - self.msg, + is_lvalue=False, + is_super=False, + is_operator=True, + msg=self.msg, original_type=original_type, chk=self.chk, in_literal_context=self.is_literal_context(), + suggest_awaitable=False, ) return self.check_method_call(method, base_type, method_type, args, arg_kinds, context) @@ -4834,13 +4835,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals if is_async_def(subexpr_type) and not has_coroutine_decorator(return_type): self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) - any_type = AnyType(TypeOfAny.special_form) - generic_generator_type = self.chk.named_generic_type( - "typing.Generator", [any_type, any_type, any_type] - ) - iter_type, _ = self.check_method_call_by_name( - "__iter__", subexpr_type, [], [], context=generic_generator_type - ) + iter_type, _ = self.chk.analyze_iterable_item_type(TempNode(subexpr_type)) + iter_type = get_proper_type(iter_type) else: if not (is_async_def(subexpr_type) and has_coroutine_decorator(return_type)): self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index ea2544442531..c48eb4c59a29 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -90,6 +90,7 @@ def __init__( chk: mypy.checker.TypeChecker, self_type: Type | None, module_symbol_table: SymbolTable | None = None, + suggest_awaitable: bool = True, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -100,6 +101,7 @@ def __init__( self.msg = msg self.chk = chk self.module_symbol_table = module_symbol_table + self.suggest_awaitable = suggest_awaitable def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -149,6 +151,7 @@ def analyze_member_access( in_literal_context: bool = False, self_type: Type | None = None, module_symbol_table: SymbolTable | None = None, + suggest_awaitable: bool = True, ) -> Type: """Return the type of attribute 'name' of 'typ'. @@ -183,6 +186,7 @@ def analyze_member_access( chk=chk, self_type=self_type, module_symbol_table=module_symbol_table, + suggest_awaitable=suggest_awaitable, ) result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) @@ -258,7 +262,7 @@ def report_missing_attribute( override_info: TypeInfo | None = None, ) -> Type: res_type = mx.msg.has_no_attr(original_type, typ, name, mx.context, mx.module_symbol_table) - if may_be_awaitable_attribute(name, typ, mx, override_info): + if mx.suggest_awaitable and may_be_awaitable_attribute(name, typ, mx, override_info): mx.msg.possible_missing_await(mx.context) return res_type diff --git a/mypy/constraints.py b/mypy/constraints.py index 05bc680230ee..56f41647ba4c 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -784,8 +784,12 @@ def infer_constraints_from_protocol_members( """ res = [] for member in protocol.type.protocol_members: - inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj) - temp = mypy.subtypes.find_member(member, template, subtype) + if member == "__iter__": + inst, temp = mypy.subtypes.iter_special_member(member, instance, template, subtype) + else: + inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj) + temp = mypy.subtypes.find_member(member, template, subtype) + if inst is None or temp is None: return [] # See #11020 # The above is safe since at this point we know that 'instance' is a subtype diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 1efdc7985e57..96b74e97dc01 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Any, Callable, Iterator, List, TypeVar, cast +from typing import Any, Callable, Iterator, List, Tuple, TypeVar, cast from typing_extensions import Final, TypeAlias as _TypeAlias import mypy.applytype @@ -14,6 +14,7 @@ # Circular import; done in the function instead. # import mypy.solve from mypy.nodes import ( + ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, @@ -969,21 +970,20 @@ def f(self) -> A: ... ignore_names = member != "__call__" # __call__ can be passed kwargs # The third argument below indicates to what self type is bound. # We always bind self to the subtype. (Similarly to nominal types). - supertype = get_proper_type(find_member(member, right, left)) - assert supertype is not None + + # TODO: refactor this and `constraints.py` into something more readable if member == "__call__" and class_obj: - # Special case: class objects always have __call__ that is just the constructor. - # TODO: move this helper function to typeops.py? - import mypy.checkmember + supertype, subtype = call_special_member(left, right, left) + elif member == "__iter__": + supertype, subtype = iter_special_member(member, right, left, left) + else: + supertype = find_member(member, right, left) + subtype = find_member(member, left, left, class_obj=class_obj) - def named_type(fullname: str) -> Instance: - return Instance(left.type.mro[-1], []) + supertype = get_proper_type(supertype) + assert supertype is not None + subtype = get_proper_type(subtype) - subtype: ProperType | None = mypy.checkmember.type_object_type( - left.type, named_type - ) - else: - subtype = get_proper_type(find_member(member, left, left, class_obj=class_obj)) # Useful for debugging: # print(member, 'of', left, 'has type', subtype) # print(member, 'of', right, 'has type', supertype) @@ -1042,6 +1042,68 @@ def named_type(fullname: str) -> Instance: return True +def iter_special_member( + name: str, supertype: Instance, subtype: Instance, context: Type +) -> Tuple[Type | None, Type | None]: + """Find types of member by name for two instances. + + We do it with respect to some special cases, like `Iterable` and `__geitem__`. + """ + # So, this is a special case: old-style iterbale protocol + # must be supported even without explicit `__iter__` method. + # Because all types with `__geitem__` defined have default `__iter__` + # implementation. See #2220 + + def _find_iter( + iterable: Instance, candidate: Instance, context: Type + ) -> Tuple[Type | None, Type | None]: + iterable_method = get_proper_type(find_member("__iter__", iterable, context)) + candidate_method = get_proper_type(find_member("__getitem__", candidate, context)) + if isinstance(iterable_method, CallableType): + ret = get_proper_type(iterable_method.ret_type) + if isinstance(ret, Instance): + # We need to transform + # `__iter__() -> Iterable[ret]` into + # `__getitem__(Any) -> ret` + iterable_method = iterable_method.copy_modified( + arg_names=[None], + arg_types=[AnyType(TypeOfAny.implementation_artifact)], + arg_kinds=[ARG_POS], + ret_type=ret.args[0], + name="__getitem__", + ) + return (iterable_method, candidate_method) + return None, None + + # First, we need to find which is one actually `Iterable`: + if is_named_instance(supertype, "typing.Iterable"): + left, right = _find_iter(supertype, subtype, context) + if left is not None and right is not None: + return left, right + elif is_named_instance(subtype, "typing.Iterable"): + left, right = _find_iter(subtype, supertype, context) + if left is not None and right is not None: + return right, left + + # This is not a special case. + # Falling back to regular `find_member` call: + return (find_member(name, supertype, context), find_member(name, subtype, context)) + + +def call_special_member( + left: Instance, right: Instance, context: Instance +) -> Tuple[Type | None, Type | None]: + """Special case: class objects always have __call__ that is just the constructor.""" + # TODO: move this helper function to typeops.py? + import mypy.checkmember + + def named_type(fullname: str) -> Instance: + return Instance(left.type.mro[-1], []) + + subtype = mypy.checkmember.type_object_type(left.type, named_type) + return find_member("__call__", right, context), subtype + + def find_member( name: str, itype: Instance, subtype: Type, is_operator: bool = False, class_obj: bool = False ) -> Type | None: diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 8d8598bc358e..bf7badf6d03f 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -2051,6 +2051,158 @@ reveal_type(list(b for b in B())) # N: Revealed type is "builtins.list[__main__ reveal_type(list(B())) # N: Revealed type is "builtins.list[__main__.B]" [builtins fixtures/list.pyi] +[case testOldStyleIterableOnClass] +from typing import Tuple, TypeVar, Generic + +class CorrectGetItem: + def __getitem__(self, arg: int) -> str: pass +class CorrectChild(CorrectGetItem): pass + +T = TypeVar('T') +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass + +class ReturnsTuple: + def __getitem__(self, arg: int) -> Tuple[int, str]: pass + +# Despite the fact that `arg` must have `int` type due to the spec, +# we allow other signatures right now. +class WrongGetItemSig: + def __getitem__(self, arg: str) -> str: pass + +reveal_type(list(a for a in CorrectGetItem())) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(a for a in CorrectChild())) # N: Revealed type is "builtins.list[builtins.str]" + +x: WithGeneric[int] +reveal_type(list(a for a in x)) # N: Revealed type is "builtins.list[builtins.int]" + +reveal_type(list(a for a in ReturnsTuple())) # N: Revealed type is "builtins.list[Tuple[builtins.int, builtins.str]]" + +reveal_type(list(a for a in WrongGetItemSig())) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + +[case testOldIterableTypeAliasAndFor] +from typing import Iterator + +KeyId = str + +class A: + def __getitem__(self, arg: int) -> KeyId: pass + +for a in A(): + reveal_type(a) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testIterableAndOldIterableProtocolsOnClass] +from typing import Iterator + +class A: + def __iter__(self) -> Iterator[str]: pass + def __getitem__(self, arg: int) -> int: pass +class B(A): pass + +reveal_type(list(a for a in A())) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(b for b in B())) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + +[case testOldStyleIterableProtocolYieldFrom] +from typing import Iterator +class A: + def __getitem__(self, arg: int) -> int: pass +def f() -> Iterator[str]: + yield from A() # E: Incompatible types in "yield from" (actual type "int", expected type "str") +[builtins fixtures/list.pyi] + +[case testOldStyleIterableProtocolUnpack] +class A: + def __getitem__(self, arg: int) -> str: pass + +a, b = A() +reveal_type(a) # N: Revealed type is "builtins.str" +reveal_type(b) # N: Revealed type is "builtins.str" +[builtins fixtures/list.pyi] + +[case testGetItemProtocolJustInCase] +from typing import Protocol, Generic, TypeVar + +TCO = TypeVar('TCO', covariant=True) +T = TypeVar('T') + +class GetItemProto(Protocol[TCO]): + def __getitem__(self, arg: int) -> TCO: pass + +def getitem(p: GetItemProto[T]) -> T: + return p[0] + +class Regular: + def __getitem__(self, arg: int) -> str: pass +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass +class DifferentSig: + def __getitem__(self, arg: str) -> str: pass + +reveal_type(getitem(Regular())) # N: Revealed type is "builtins.str" + +x: WithGeneric[str] +reveal_type(getitem(x)) # N: Revealed type is "builtins.str" + +getitem(DifferentSig()) # E: Argument 1 to "getitem" has incompatible type "DifferentSig"; expected "GetItemProto[]" +[builtins fixtures/list.pyi] + +[case testIterableAndGetItemAreCompatible] +from typing import Iterable, TypeVar, Generic + +T = TypeVar('T') + +def some_generic(i: Iterable[T]) -> T: pass +def some_regular(i: Iterable[int]): pass + +class Regular: + def __getitem__(self, arg: int) -> int: pass +class WithGeneric(Generic[T]): + def __getitem__(self, arg: int) -> T: pass + +reveal_type(some_generic(Regular())) # N: Revealed type is "builtins.int" +x: WithGeneric[int] +reveal_type(some_generic(x)) # N: Revealed type is "builtins.int" + +some_regular(Regular()) +some_regular(x) +y: WithGeneric[str] +some_regular(y) # E: Argument 1 to "some_regular" has incompatible type "WithGeneric[str]"; expected "Iterable[int]" +[builtins fixtures/list.pyi] + +[case testIterableAndGetItemProtocolsAreNotCompatible] +from typing import Iterable, Protocol + +def expects_iterable(i: Iterable[int]): pass + +class GetItem(Protocol): + def __getitem__(self, arg: int) -> int: pass + +def expects_getitem(i: GetItem): pass + +x: GetItem +expects_iterable(x) # E: Argument 1 to "expects_iterable" has incompatible type "GetItem"; expected "Iterable[int]" + +y: Iterable[int] +expects_getitem(y) # E: Argument 1 to "expects_getitem" has incompatible type "Iterable[int]"; expected "GetItem" +[builtins fixtures/list.pyi] + +[case testCustomIterableAndGetItemProtocolsAreNotCompatible] +from typing import Protocol, Iterator + +class MyIterable(Protocol): + def __iter__(self) -> Iterator[int]: pass + +class Regular: + def __getitem__(self, arg: int) -> int: pass + +def some(i: MyIterable): pass + +some(Regular()) # E: Argument 1 to "some" has incompatible type "Regular"; expected "MyIterable" +[builtins fixtures/list.pyi] + [case testIterableProtocolOnMetaclass] from typing import TypeVar, Iterator, Type T = TypeVar('T') @@ -2068,6 +2220,17 @@ reveal_type(list(c for c in C)) # N: Revealed type is "builtins.list[__main__.C reveal_type(list(C)) # N: Revealed type is "builtins.list[__main__.C]" [builtins fixtures/list.pyi] +[case testOldStyleIterableProtocolOnMetaclass] +class EMeta(type): + def __getitem__(self, arg: int) -> str: pass + +class E(metaclass=EMeta): pass +class C(E): pass + +reveal_type(list(e for e in E)) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(list(c for c in C)) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + [case testClassesGetattrWithProtocols] from typing import Protocol diff --git a/test-data/unit/deps-statements.test b/test-data/unit/deps-statements.test index a67f9c762009..70d349e87bb0 100644 --- a/test-data/unit/deps-statements.test +++ b/test-data/unit/deps-statements.test @@ -256,8 +256,8 @@ def f() -> None: a: A x, y = a [out] - -> , m.f - -> m.A, m.f, typing.Iterable + -> m.f + -> m.A, m.f [case testMultipleLvalues] class A: @@ -327,8 +327,8 @@ def g() -> None: -> m.A.f, m.g -> m.A, m.g -> m.g - -> , m.g - -> , m.B, m.f, typing.Iterable + -> m.g + -> , m.B, m.f -> m.g [case testNestedSetItem] @@ -449,9 +449,22 @@ def f() -> Iterator[int]: yield from A() [out] -> m.f - -> , m.f + -> m.f -> m.f - -> m.A, m.f, typing.Iterable + -> m.A, m.f + +[case testForAndOldStyleIterable] +class A: + def __getitem__(self, arg: int) -> int: pass + +for a in A(): pass +[out] + -> m + -> m + -> m + -> m + -> m, m.A + -> m [case testFunctionDecorator] from typing import Callable diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 0e443abc7237..bed9d1b586c4 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -6956,6 +6956,51 @@ class B: == main:3: error: "A" has no attribute "__iter__" (not iterable) +[case testWeAreCarefulWithOldIterProtocol] +import a +x: a.A +for i in x: + pass +[file a.py] +class A: + def __getitem__(self, arg: int) -> str: + pass +[file a.py.2] +class A: + pass +[out] +== +main:3: error: "A" has no attribute "__iter__" (not iterable) + +[case testWeAreCarefulWithIterProtocolSwitches] +import a +x: a.A +for i in x: + reveal_type(i) + +y: a.B +for j in y: + reveal_type(j) +[file a.py] +from typing import Iterator +class A: + def __iter__(self) -> Iterator[int]: pass +class B: + def __getitem__(self, arg: int) -> str: pass +[file a.py.2] +from typing import Iterator +class A: + def __getitem__(self, arg: int) -> str: pass +class B: + def __iter__(self) -> Iterator[int]: pass +[out] +main:4: note: Revealed type is "builtins.int" +main:8: note: Revealed type is "builtins.str" +== +main:4: note: Revealed type is "builtins.str" +main:8: note: Revealed type is "builtins.int" + + [case testOverloadsSimpleFrom] import a [file a.py] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index d7d20a923984..f2353bc9b0f9 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -27,6 +27,18 @@ print(list(reversed(A()))) \['c', 'b', 'a'] \['f', 'o', 'o'] +[case testOldIterableAndIterableProtocolAreIncompatible] +from typing import Iterable + +def some(i: Iterable[int]): pass + +class A: + def __getitem__(self, arg: int) -> int: pass + +# We treat `Iterable` and types with `__getitem__` compatible: see #13485 +some(A()) +[out] + [case testIntAndFloatConversion] from typing import SupportsInt, SupportsFloat class A(SupportsInt):