diff --git a/docs/source/class_basics.rst b/docs/source/class_basics.rst index dc778d39424f..e03c299e977c 100644 --- a/docs/source/class_basics.rst +++ b/docs/source/class_basics.rst @@ -151,7 +151,174 @@ concrete. As with normal overrides, a dynamically typed method can implement a statically typed abstract method defined in an abstract base class. +.. _protocol-types: + +Protocols and structural subtyping +********************************** + +.. note:: + + The support for structural subtyping is still experimental. Some features + might be not yet implemented, mypy could pass unsafe code or reject + working code. + +There are two main type systems with respect to subtyping: nominal subtyping +and structural subtyping. The *nominal* subtyping is based on class hierarchy, +so that if class ``D`` inherits from class ``C``, then it is a subtype +of ``C``. This type system is primarily used in mypy since it allows +to produce clear and concise error messages, and since Python provides native +``isinstance()`` checks based on class hierarchy. The *structural* subtyping +however has its own advantages. In this system class ``D`` is a subtype +of class ``C`` if the former has all attributes of the latter with +compatible types. + +This type system is a static equivalent of duck typing, well known by Python +programmers. Mypy provides an opt-in support for structural subtyping via +protocol classes described in this section. +See `PEP 544 `_ for +specification of protocols and structural subtyping in Python. + +User defined protocols +********************** + +To define a protocol class, one must inherit the special +``typing_extensions.Protocol`` class: + +.. code-block:: python + + from typing import Iterable + from typing_extensions import Protocol + + class SupportsClose(Protocol): + def close(self) -> None: + ... + + class Resource: # Note, this class does not have 'SupportsClose' base. + # some methods + def close(self) -> None: + self.resource.release() + + def close_all(things: Iterable[SupportsClose]) -> None: + for thing in things: + thing.close() + + close_all([Resource(), open('some/file')]) # This passes type check + +.. note:: + + The ``Protocol`` base class is currently provided in ``typing_extensions`` + package. When structural subtyping is mature and + `PEP 544 `_ is accepted, + ``Protocol`` will be included in the ``typing`` module. As well, several + types such as ``typing.Sized``, ``typing.Iterable`` etc. will be made + protocols. + +Defining subprotocols +********************* + +Subprotocols are also supported. Existing protocols can be extended +and merged using multiple inheritance. For example: + +.. code-block:: python + + # continuing from previous example + + class SupportsRead(Protocol): + def read(self, amount: int) -> bytes: ... + + class TaggedReadableResource(SupportsClose, SupportsRead, Protocol): + label: str + + class AdvancedResource(Resource): + def __init__(self, label: str) -> None: + self.label = label + def read(self, amount: int) -> bytes: + # some implementation + ... + + resource = None # type: TaggedReadableResource + + # some code + + resource = AdvancedResource('handle with care') # OK + +Note that inheriting from existing protocols does not automatically turn +a subclass into a protocol, it just creates a usual (non-protocol) ABC that +implements given protocols. The ``typing_extensions.Protocol`` base must always +be explicitly present: + +.. code-block:: python + + class NewProtocol(SupportsClose): # This is NOT a protocol + new_attr: int + + class Concrete: + new_attr = None # type: int + def close(self) -> None: + ... + # Below is an error, since nominal subtyping is used by default + x = Concrete() # type: NewProtocol # Error! + +.. note:: + + The `PEP 526 `_ variable + annotations can be used to declare protocol attributes. However, protocols + are also supported on Python 2.7 and Python 3.3+ with the help of type + comments and properties, see + `backwards compatibility in PEP 544 `_. + +Recursive protocols +******************* + +Protocols can be recursive and mutually recursive. This could be useful for +declaring abstract recursive collections such as trees and linked lists: + +.. code-block:: python + + from typing import TypeVar, Optional + from typing_extensions import Protocol + + class TreeLike(Protocol): + value: int + @property + def left(self) -> Optional['TreeLike']: ... + @property + def right(self) -> Optional['TreeLike']: ... + + class SimpleTree: + def __init__(self, value: int) -> None: + self.value = value + self.left: Optional['SimpleTree'] = None + self.right: Optional['SimpleTree'] = None + + root = SimpleTree(0) # type: TreeLike # OK + +Using ``isinstance()`` with protocols +************************************* + +To use a protocol class with ``isinstance()``, one needs to decorate it with +a special ``typing_extensions.runtime`` decorator. It will add support for +basic runtime structural checks: + +.. code-block:: python + + from typing_extensions import Protocol, runtime + + @runtime + class Portable(Protocol): + handles: int + + class Mug: + def __init__(self) -> None: + self.handles = 1 + + mug = Mug() + if isinstance(mug, Portable): + use(mug.handles) # Works statically and at runtime. + .. note:: + ``isinstance()`` is with protocols not completely safe at runtime. + For example, signatures of methods are not checked. The runtime + implementation only checks the presence of all protocol members + in object's MRO. - There are also plans to support more Python-style "duck typing" in - the type system. The details are still open. diff --git a/docs/source/common_issues.rst b/docs/source/common_issues.rst index 0c8b500d8f06..e31d9a7bc4e6 100644 --- a/docs/source/common_issues.rst +++ b/docs/source/common_issues.rst @@ -226,6 +226,48 @@ Possible strategies in such situations are: return x[0] f_good(new_lst) # OK +Covariant subtyping of mutable protocol members is rejected +----------------------------------------------------------- + +Mypy rejects this because this is potentially unsafe. +Consider this example: + +.. code-block:: python + + from typing_extensions import Protocol + + class P(Protocol): + x: float + + def fun(arg: P) -> None: + arg.x = 3.14 + + class C: + x = 42 + c = C() + fun(c) # This is not safe + c.x << 5 # Since this will fail! + +To work around this problem consider whether "mutating" is actually part +of a protocol. If not, then one can use a ``@property`` in +the protocol definition: + +.. code-block:: python + + from typing_extensions import Protocol + + class P(Protocol): + @property + def x(self) -> float: + pass + + def fun(arg: P) -> None: + ... + + class C: + x = 42 + fun(C()) # OK + Declaring a supertype as variable type -------------------------------------- diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 9fd73b42b8a3..b131e3f7e2a2 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -101,35 +101,38 @@ Is mypy free? Yes. Mypy is free software, and it can also be used for commercial and proprietary projects. Mypy is available under the MIT license. -Why not use structural subtyping? -********************************* +Can I use structural subtyping? +******************************* -Mypy primarily uses `nominal subtyping -`_ instead of +Mypy provides support for both `nominal subtyping +`_ and `structural subtyping -`_. Some argue -that structural subtyping is better suited for languages with duck -typing such as Python. - -Here are some reasons why mypy uses nominal subtyping: +`_. +Support for structural subtyping is considered experimental. +Some argue that structural subtyping is better suited for languages with duck +typing such as Python. Mypy however primarily uses nominal subtyping, +leaving structural subtyping opt-in. Here are some reasons why: 1. It is easy to generate short and informative error messages when using a nominal type system. This is especially important when using type inference. -2. Python supports basically nominal isinstance tests and they are - widely used in programs. It is not clear how to support isinstance - in a purely structural type system while remaining compatible with - Python idioms. +2. Python provides built-in support for nominal ``isinstance()`` tests and + they are widely used in programs. Only limited support for structural + ``isinstance()`` exists for ABCs in ``collections.abc`` and ``typing`` + standard library modules. 3. Many programmers are already familiar with nominal subtyping and it has been successfully used in languages such as Java, C++ and C#. Only few languages use structural subtyping. -However, structural subtyping can also be useful. Structural subtyping -is a likely feature to be added to mypy in the future, even though we -expect that most mypy programs will still primarily use nominal -subtyping. +However, structural subtyping can also be useful. For example, a "public API" +may be more flexible if it is typed with protocols. Also, using protocol types +removes the necessity to explicitly declare implementations of ABCs. +As a rule of thumb, we recommend using nominal classes where possible, and +protocols where necessary. For more details about protocol types and structural +subtyping see :ref:`protocol-types` and +`PEP 544 `_. I like Python and I have no need for static typing ************************************************** diff --git a/docs/source/generics.rst b/docs/source/generics.rst index bd0e0549fd1a..8f8f5355b93e 100644 --- a/docs/source/generics.rst +++ b/docs/source/generics.rst @@ -1,6 +1,8 @@ Generics ======== +.. _generic-classes: + Defining generic classes ************************ @@ -489,6 +491,73 @@ restrict the valid values for the type parameter in the same way. A type variable may not have both a value restriction (see :ref:`type-variable-value-restriction`) and an upper bound. +Generic protocols +***************** + +Generic protocols (see :ref:`protocol-types`) are also supported, generic +protocols mostly follow the normal rules for generic classes, the main +difference is that mypy checks that declared variance of type variables is +compatible with the class definition. Examples: + +.. code-block:: python + + from typing import TypeVar + from typing_extensions import Protocol + + T = TypeVar('T') + + class Box(Protocol[T]): + content: T + + def do_stuff(one: Box[str], other: Box[bytes]) -> None: + ... + + class StringWrapper: + def __init__(self, content: str) -> None: + self.content = content + + class BytesWrapper: + def __init__(self, content: bytes) -> None: + self.content = content + + do_stuff(StringWrapper('one'), BytesWrapper(b'other')) # OK + + x = None # type: Box[float] + y = None # type: Box[int] + x = y # Error, since the protocol 'Box' is invariant. + + class AnotherBox(Protocol[T]): # Error, covariant type variable expected + def content(self) -> T: + ... + + T_co = TypeVar('T_co', covariant=True) + class AnotherBox(Protocol[T_co]): # OK + def content(self) -> T_co: + ... + + ax = None # type: AnotherBox[float] + ay = None # type: AnotherBox[int] + ax = ay # OK for covariant protocols + +See :ref:`variance-of-generics` above for more details on variance. +Generic protocols can be recursive, for example: + +.. code-block:: python + + T = TypeVar('T') + class Linked(Protocol[T]): + val: T + def next(self) -> 'Linked[T]': ... + + class L: + val: int + def next(self) -> 'L': ... + + def last(seq: Linked[T]) -> T: + ... + + result = last(L()) # The inferred type of 'result' is 'int' + .. _declaring-decorators: Declaring decorators diff --git a/mypy/checker.py b/mypy/checker.py index ca9cf379df06..16329871b0bc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -27,7 +27,7 @@ RefExpr, YieldExpr, BackquoteExpr, Import, ImportFrom, ImportAll, ImportBase, AwaitExpr, PromoteExpr, Node, EnumCallExpr, ARG_POS, MDEF, - CONTRAVARIANT, COVARIANT) + CONTRAVARIANT, COVARIANT, INVARIANT) from mypy import nodes from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any from mypy.types import ( @@ -44,7 +44,7 @@ from mypy.subtypes import ( is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype, - unify_generic_callable, + unify_generic_callable, find_member ) from mypy.maptype import map_instance_to_supertype from mypy.typevars import fill_typevars, has_no_typevars @@ -1147,6 +1147,8 @@ def erase_override(t: Type) -> Type: def visit_class_def(self, defn: ClassDef) -> None: """Type check a class definition.""" typ = defn.info + if typ.is_protocol and typ.defn.type_vars: + self.check_protocol_variance(defn) with self.errors.enter_type(defn.name), self.enter_partial_types(): old_binder = self.binder self.binder = ConditionalTypeBinder() @@ -1158,6 +1160,33 @@ def visit_class_def(self, defn: ClassDef) -> None: # Otherwise we've already found errors; more errors are not useful self.check_multiple_inheritance(typ) + def check_protocol_variance(self, defn: ClassDef) -> None: + """Check that protocol definition is compatible with declared + variances of type variables. + + Note that we also prohibit declaring protocol classes as invariant + if they are actually covariant/contravariant, since this may break + transitivity of subtyping, see PEP 544. + """ + info = defn.info + object_type = Instance(info.mro[-1], []) + tvars = info.defn.type_vars + for i, tvar in enumerate(tvars): + up_args = [object_type if i == j else AnyType(TypeOfAny.special_form) + for j, _ in enumerate(tvars)] + down_args = [UninhabitedType() if i == j else AnyType(TypeOfAny.special_form) + for j, _ in enumerate(tvars)] + up, down = Instance(info, up_args), Instance(info, down_args) + # TODO: add advanced variance checks for recursive protocols + if is_subtype(down, up, ignore_declared_variance=True): + expected = COVARIANT + elif is_subtype(up, down, ignore_declared_variance=True): + expected = CONTRAVARIANT + else: + expected = INVARIANT + if expected != tvar.variance: + self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn) + def check_multiple_inheritance(self, typ: TypeInfo) -> None: """Check for multiple inheritance related errors.""" if len(typ.bases) <= 1: @@ -1328,15 +1357,16 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type else: rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, lvalue) - # Special case: only non-abstract classes can be assigned to variables - # with explicit type Type[A]. + # Special case: only non-abstract non-protocol classes can be assigned to + # variables with explicit type Type[A], where A is protocol or abstract. if (isinstance(rvalue_type, CallableType) and rvalue_type.is_type_obj() and - rvalue_type.type_object().is_abstract and + (rvalue_type.type_object().is_abstract or + rvalue_type.type_object().is_protocol) and isinstance(lvalue_type, TypeType) and isinstance(lvalue_type.item, Instance) and - lvalue_type.item.type.is_abstract): - self.fail("Can only assign non-abstract classes" - " to a variable of type '{}'".format(lvalue_type), rvalue) + (lvalue_type.item.type.is_abstract or + lvalue_type.item.type.is_protocol)): + self.msg.concrete_only_assign(lvalue_type, rvalue) return if rvalue_type and infer_lvalue_type: self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False) @@ -2468,6 +2498,13 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context, self.fail(msg, context) if note_msg: self.note(note_msg, context) + if (isinstance(supertype, Instance) and supertype.type.is_protocol and + isinstance(subtype, (Instance, TupleType, TypedDictType))): + self.msg.report_protocol_problems(subtype, supertype, context) + if isinstance(supertype, CallableType) and isinstance(subtype, Instance): + call = find_member('__call__', subtype, subtype) + if call: + self.msg.note_call(subtype, call, context) return False def contains_none(self, t: Type) -> bool: @@ -2609,9 +2646,9 @@ def warn(self, msg: str, context: Context) -> None: """Produce a warning message.""" self.msg.warn(msg, context) - def note(self, msg: str, context: Context) -> None: + def note(self, msg: str, context: Context, offset: int = 0) -> None: """Produce a note.""" - self.msg.note(msg, context) + self.msg.note(msg, context, offset=offset) def iterable_item_type(self, instance: Instance) -> Type: iterable = map_instance_to_supertype( diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a26637dfd544..1d50aa749274 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -35,7 +35,7 @@ from mypy import join from mypy.meet import narrow_declared_type from mypy.maptype import map_instance_to_supertype -from mypy.subtypes import is_subtype, is_equivalent +from mypy.subtypes import is_subtype, is_equivalent, find_member from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type, bind_self @@ -245,6 +245,15 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: callee_type = self.apply_method_signature_hook( e, callee_type, object_type, signature_hook) ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type) + if (isinstance(e.callee, RefExpr) and len(e.args) == 2 and + e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass')): + for expr in mypy.checker.flatten(e.args[1]): + tp = self.chk.type_map[expr] + if (isinstance(tp, CallableType) and tp.is_type_obj() and + tp.type_object().is_protocol and + not tp.type_object().runtime_protocol): + self.chk.fail('Only @runtime protocols can be used with' + ' instance and class checks', e) if isinstance(ret_type, UninhabitedType): self.chk.binder.unreachable() if not allow_none_return and isinstance(ret_type, NoneTyp): @@ -502,6 +511,11 @@ def check_call(self, callee: Type, args: List[Expression], self.msg.cannot_instantiate_abstract_class( callee.type_object().name(), type.abstract_attributes, context) + elif (callee.is_type_obj() and callee.type_object().is_protocol + # Exceptions for Type[...] and classmethod first argument + and not callee.from_type_type and not callee.is_classmethod_class): + self.chk.fail('Cannot instantiate protocol class "{}"' + .format(callee.type_object().name()), context) formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, @@ -1001,19 +1015,28 @@ def check_arg(self, caller_type: Type, original_caller_type: Type, """Check the type of a single argument in a call.""" if isinstance(caller_type, DeletedType): messages.deleted_as_rvalue(caller_type, context) - # Only non-abstract class can be given where Type[...] is expected... + # Only non-abstract non-protocol class can be given where Type[...] is expected... elif (isinstance(caller_type, CallableType) and isinstance(callee_type, TypeType) and - caller_type.is_type_obj() and caller_type.type_object().is_abstract and - isinstance(callee_type.item, Instance) and callee_type.item.type.is_abstract and + caller_type.is_type_obj() and + (caller_type.type_object().is_abstract or caller_type.type_object().is_protocol) and + isinstance(callee_type.item, Instance) and + (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and # ...except for classmethod first argument not caller_type.is_classmethod_class): - messages.fail("Only non-abstract class can be given where '{}' is expected" - .format(callee_type), context) + self.msg.concrete_only_call(callee_type, context) elif not is_subtype(caller_type, callee_type): if self.chk.should_suppress_optional_error([caller_type, callee_type]): return messages.incompatible_argument(n, m, callee, original_caller_type, caller_kind, context) + if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and + isinstance(callee_type, Instance) and callee_type.type.is_protocol): + self.msg.report_protocol_problems(original_caller_type, callee_type, context) + if (isinstance(callee_type, CallableType) and + isinstance(original_caller_type, Instance)): + call = find_member('__call__', original_caller_type, original_caller_type) + if call: + self.msg.note_call(original_caller_type, call, context) def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int], arg_names: List[str], @@ -2697,6 +2720,8 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int: # subtyping algorithm if type promotions are possible (e.g., int vs. float). if formal.type in actual.type.mro: return 2 + elif formal.type.is_protocol and is_subtype(actual, erasetype.erase_type(formal)): + return 2 elif actual.type._promote and is_subtype(actual, formal): return 1 else: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 93e210523c46..ceb9a4ff00ce 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -53,6 +53,8 @@ def analyze_member_access(name: str, the fallback type, for example. original_type is always the type used in the initial call. """ + # TODO: this and following functions share some logic with subtypes.find_member, + # consider refactoring. if isinstance(typ, Instance): if name == '__init__' and not is_super: # Accessing __init__ in statically typed code would compromise diff --git a/mypy/constraints.py b/mypy/constraints.py index 0aa12322bc70..7180dd00f25f 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -305,7 +305,7 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]: # Non-leaf types def visit_instance(self, template: Instance) -> List[Constraint]: - actual = self.actual + original_actual = actual = self.actual res = [] # type: List[Constraint] if isinstance(actual, CallableType) and actual.fallback is not None: actual = actual.fallback @@ -313,6 +313,8 @@ def visit_instance(self, template: Instance) -> List[Constraint]: actual = actual.as_anonymous().fallback if isinstance(actual, Instance): instance = actual + # We always try nominal inference if possible, + # it is much faster than the structural one. if (self.direction == SUBTYPE_OF and template.type.has_base(instance.type.fullname())): mapped = map_instance_to_supertype(template, instance.type) @@ -336,6 +338,28 @@ def visit_instance(self, template: Instance) -> List[Constraint]: res.extend(infer_constraints( template.args[j], mapped.args[j], neg_op(self.direction))) return res + if (template.type.is_protocol and self.direction == SUPERTYPE_OF and + # We avoid infinite recursion for structural subtypes by checking + # whether this type already appeared in the inference chain. + # This is a conservative way break the inference cycles. + # It never produces any "false" constraints but gives up soon + # on purely structural inference cycles, see #3829. + not any(is_same_type(template, t) for t in template.type.inferring) and + mypy.subtypes.is_subtype(instance, erase_typevars(template))): + template.type.inferring.append(template) + self.infer_constraints_from_protocol_members(res, instance, template, + original_actual, template) + template.type.inferring.pop() + return res + elif (instance.type.is_protocol and self.direction == SUBTYPE_OF and + # We avoid infinite recursion for structural subtypes also here. + not any(is_same_type(instance, i) for i in instance.type.inferring) and + mypy.subtypes.is_subtype(erase_typevars(template), instance)): + instance.type.inferring.append(instance) + self.infer_constraints_from_protocol_members(res, instance, template, + template, instance) + instance.type.inferring.pop() + return res if isinstance(actual, AnyType): # IDEA: Include both ways, i.e. add negation as well? return self.infer_against_any(template.args, actual) @@ -349,9 +373,36 @@ def visit_instance(self, template: Instance) -> List[Constraint]: cb = infer_constraints(template.args[0], item, SUPERTYPE_OF) res.extend(cb) return res + elif (isinstance(actual, TupleType) and template.type.is_protocol and + self.direction == SUPERTYPE_OF): + if mypy.subtypes.is_subtype(actual.fallback, erase_typevars(template)): + res.extend(infer_constraints(template, actual.fallback, self.direction)) + return res + return [] else: return [] + def infer_constraints_from_protocol_members(self, res: List[Constraint], + instance: Instance, template: Instance, + subtype: Type, protocol: Instance) -> None: + """Infer constraints for situations where either 'template' or 'instance' is a protocol. + + The 'protocol' is the one of two that is an instance of protocol type, 'subtype' + is the type used to bind self during inference. Currently, we just infer constrains for + every protocol member type (both ways for settable members). + """ + for member in protocol.type.protocol_members: + inst = mypy.subtypes.find_member(member, instance, subtype) + temp = mypy.subtypes.find_member(member, template, subtype) + assert inst is not None and temp is not None + # The above is safe since at this point we know that 'instance' is a subtype + # of (erased) 'template', therefore it defines all protocol members + res.extend(infer_constraints(temp, inst, self.direction)) + if (mypy.subtypes.IS_SETTABLE in + mypy.subtypes.get_member_flags(member, protocol.type)): + # Settable members are invariant, add opposite constraints + res.extend(infer_constraints(temp, inst, neg_op(self.direction))) + def visit_callable_type(self, template: CallableType) -> List[Constraint]: if isinstance(self.actual, CallableType): cactual = self.actual @@ -379,6 +430,14 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: return self.infer_against_overloaded(self.actual, template) elif isinstance(self.actual, TypeType): return infer_constraints(template.ret_type, self.actual.item, self.direction) + elif isinstance(self.actual, Instance): + # Instances with __call__ method defined are considered structural + # subtypes of Callable with a compatible signature. + call = mypy.subtypes.find_member('__call__', self.actual, self.actual) + if call: + return infer_constraints(template, call, self.direction) + else: + return [] else: return [] diff --git a/mypy/errors.py b/mypy/errors.py index 1da550a8e68b..eb6a17efaf3c 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -11,6 +11,7 @@ T = TypeVar('T') +allowed_duplicates = ['@overload', 'Got:', 'Expected:'] class ErrorInfo: @@ -261,7 +262,8 @@ def report(self, severity: str = 'error', file: Optional[str] = None, only_once: bool = False, - origin_line: Optional[int] = None) -> None: + origin_line: Optional[int] = None, + offset: int = 0) -> None: """Report message at the given line using the current error context. Args: @@ -278,6 +280,8 @@ def report(self, type = None # Omit type context if nested function if file is None: file = self.file + if offset: + message = " " * offset + message info = ErrorInfo(self.import_context(), file, self.current_module(), type, self.function_or_member[-1], line, column, severity, message, blocker, only_once, @@ -479,6 +483,10 @@ def remove_duplicates(self, errors: List[Tuple[Optional[str], int, int, str, str while (j >= 0 and errors[j][0] == errors[i][0] and errors[j][1] == errors[i][1]): if (errors[j][3] == errors[i][3] and + # Allow duplicate notes in overload conficts reporting + not (errors[i][3] == 'note' and + errors[i][4].strip() in allowed_duplicates + or errors[i][4].strip().startswith('def ')) and errors[j][4] == errors[i][4]): # ignore column dup = True break diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 12b7e2db69bd..8ba3f193e962 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -514,7 +514,7 @@ def visit_Assign(self, n: ast3.Assign) -> AssignmentStmt: @with_line def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt: if n.value is None: # always allow 'x: int' - rvalue = TempNode(AnyType(TypeOfAny.special_form)) # type: Expression + rvalue = TempNode(AnyType(TypeOfAny.special_form), no_rhs=True) # type: Expression else: rvalue = self.visit(n.value) typ = TypeConverter(self.errors, line=n.lineno).visit(n.annotation) diff --git a/mypy/join.py b/mypy/join.py index c771b6df86b9..e306a3f35895 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -9,7 +9,10 @@ PartialType, DeletedType, UninhabitedType, TypeType, true_or_false, TypeOfAny ) from mypy.maptype import map_instance_to_supertype -from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype +from mypy.subtypes import ( + is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype, + is_protocol_implementation +) from mypy import experiments @@ -137,7 +140,18 @@ def visit_type_var(self, t: TypeVarType) -> Type: def visit_instance(self, t: Instance) -> Type: if isinstance(self.s, Instance): - return join_instances(t, self.s) + nominal = join_instances(t, self.s) + structural = None # type: Optional[Instance] + if t.type.is_protocol and is_protocol_implementation(self.s, t): + structural = t + elif self.s.type.is_protocol and is_protocol_implementation(t, self.s): + structural = self.s + # Structural join is preferred in the case where we have found both + # structural and nominal and they have same MRO length (see two comments + # in join_instances_via_supertype). Otherwise, just return the nominal join. + if not structural or is_better(nominal, structural): + return nominal + return structural elif isinstance(self.s, FunctionLike): return join_types(t, self.s.fallback) elif isinstance(self.s, TypeType): @@ -237,7 +251,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: required_keys = set(items.keys()) & t.required_keys & self.s.required_keys return TypedDictType(items, required_keys, fallback) elif isinstance(self.s, Instance): - return join_instances(self.s, t.fallback) + return join_types(self.s, t.fallback) else: return self.default(self.s) diff --git a/mypy/meet.py b/mypy/meet.py index 632e0bbc5afa..33d7f6e3df4a 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -7,7 +7,7 @@ TupleType, TypedDictType, ErasedType, TypeList, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny ) -from mypy.subtypes import is_equivalent, is_subtype +from mypy.subtypes import is_equivalent, is_subtype, is_protocol_implementation from mypy import experiments @@ -52,7 +52,8 @@ def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool Note that this effectively checks against erased types, since type variables are erased at runtime and the overlapping check is based - on runtime behavior. + on runtime behavior. The exception is protocol types, it is not safe, + but convenient and is an opt-in behavior. If use_promotions is True, also consider type promotions (int and float would only be overlapping if it's True). @@ -100,7 +101,13 @@ class C(A, B): ... return True if s.type._promote and is_overlapping_types(s.type._promote, t): return True - return t.type in s.type.mro or s.type in t.type.mro + if t.type in s.type.mro or s.type in t.type.mro: + return True + if t.type.is_protocol and is_protocol_implementation(s, t): + return True + if s.type.is_protocol and is_protocol_implementation(t, s): + return True + return False if isinstance(t, UnionType): return any(is_overlapping_types(item, s) for item in t.relevant_items()) diff --git a/mypy/messages.py b/mypy/messages.py index 022d962f6706..21bf2bf85334 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -6,7 +6,7 @@ import re import difflib -from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple, Optional +from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple, Set, Optional, Union from mypy.erasetype import erase_type from mypy.errors import Errors @@ -18,7 +18,7 @@ from mypy.nodes import ( TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, - ReturnStmt, NameExpr, Var + ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT ) @@ -159,12 +159,13 @@ def is_errors(self) -> bool: return self.errors.is_errors() def report(self, msg: str, context: Context, severity: str, - file: Optional[str] = None, origin: Optional[Context] = None) -> None: + file: Optional[str] = None, origin: Optional[Context] = None, + offset: int = 0) -> None: """Report an error or note (unless disabled).""" if self.disable_count <= 0: self.errors.report(context.get_line() if context else -1, context.get_column() if context else -1, - msg.strip(), severity=severity, file=file, + msg.strip(), severity=severity, file=file, offset=offset, origin_line=origin.get_line() if origin else None) def fail(self, msg: str, context: Context, file: Optional[str] = None, @@ -173,9 +174,9 @@ def fail(self, msg: str, context: Context, file: Optional[str] = None, self.report(msg, context, 'error', file=file, origin=origin) def note(self, msg: str, context: Context, file: Optional[str] = None, - origin: Optional[Context] = None) -> None: + origin: Optional[Context] = None, offset: int = 0) -> None: """Report a note (unless disabled).""" - self.report(msg, context, 'note', file=file, origin=origin) + self.report(msg, context, 'note', file=file, origin=origin, offset=offset) def warn(self, msg: str, context: Context, file: Optional[str] = None, origin: Optional[Context] = None) -> None: @@ -998,6 +999,252 @@ def untyped_decorated_function(self, typ: Type, context: Context) -> None: def typed_function_untyped_decorator(self, func_name: str, context: Context) -> None: self.fail('Untyped decorator makes function "{}" untyped'.format(func_name), context) + def bad_proto_variance(self, actual: int, tvar_name: str, expected: int, + context: Context) -> None: + msg = capitalize("{} type variable '{}' used in protocol where" + " {} one is expected".format(variance_string(actual), + tvar_name, + variance_string(expected))) + self.fail(msg, context) + + def concrete_only_assign(self, typ: Type, context: Context) -> None: + self.fail("Can only assign concrete classes to a variable of type '{}'" + .format(self.format(typ)), context) + + def concrete_only_call(self, typ: Type, context: Context) -> None: + self.fail("Only concrete class can be given where '{}' is expected" + .format(self.format(typ)), context) + + def note_call(self, subtype: Type, call: Type, context: Context) -> None: + self.note("'{}.__call__' has type '{}'".format(strip_quotes(self.format(subtype)), + self.format(call, verbosity=1)), context) + + def report_protocol_problems(self, subtype: Union[Instance, TupleType, TypedDictType], + supertype: Instance, context: Context) -> None: + """Report possible protocol conflicts between 'subtype' and 'supertype'. + + This includes missing members, incompatible types, and incompatible + attribute flags, such as settable vs read-only or class variable vs + instance variable. + """ + from mypy.subtypes import is_subtype, IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC + OFFSET = 4 # Four spaces, so that notes will look like this: + # note: 'Cls' is missing following 'Proto' members: + # note: method, attr + MAX_ITEMS = 2 # Maximum number of conflicts, missing members, and overloads shown + # List of special situations where we don't want to report additional problems + exclusions = {TypedDictType: ['typing.Mapping'], + TupleType: ['typing.Iterable', 'typing.Sequence'], + Instance: []} # type: Dict[type, List[str]] + if supertype.type.fullname() in exclusions[type(subtype)]: + return + if any(isinstance(tp, UninhabitedType) for tp in supertype.args): + # We don't want to add notes for failed inference (e.g. Iterable[]). + # This will be only confusing a user even more. + return + + if isinstance(subtype, (TupleType, TypedDictType)): + if not isinstance(subtype.fallback, Instance): + return + subtype = subtype.fallback + + # Report missing members + missing = get_missing_protocol_members(subtype, supertype) + if (missing and len(missing) < len(supertype.type.protocol_members) and + len(missing) <= MAX_ITEMS): + self.note("'{}' is missing following '{}' protocol member{}:" + .format(subtype.type.name(), supertype.type.name(), plural_s(missing)), + context) + self.note(', '.join(missing), context, offset=OFFSET) + elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members): + # This is an obviously wrong type: too many missing members + return + + # Report member type conflicts + conflict_types = get_conflict_protocol_types(subtype, supertype) + if conflict_types and (not is_subtype(subtype, erase_type(supertype)) or + not subtype.type.defn.type_vars or + not supertype.type.defn.type_vars): + self.note('Following member(s) of {} have ' + 'conflicts:'.format(self.format(subtype)), context) + for name, got, exp in conflict_types[:MAX_ITEMS]: + if (not isinstance(exp, (CallableType, Overloaded)) or + not isinstance(got, (CallableType, Overloaded))): + self.note('{}: expected {}, got {}'.format(name, + *self.format_distinctly(exp, got)), + context, offset=OFFSET) + else: + self.note('Expected:', context, offset=OFFSET) + if isinstance(exp, CallableType): + self.note(self.pretty_callable(exp), context, offset=2 * OFFSET) + else: + assert isinstance(exp, Overloaded) + self.pretty_overload(exp, context, OFFSET, MAX_ITEMS) + self.note('Got:', context, offset=OFFSET) + if isinstance(got, CallableType): + self.note(self.pretty_callable(got), context, offset=2 * OFFSET) + else: + assert isinstance(got, Overloaded) + self.pretty_overload(got, context, OFFSET, MAX_ITEMS) + self.print_more(conflict_types, context, OFFSET, MAX_ITEMS) + + # Report flag conflicts (i.e. settable vs read-only etc.) + conflict_flags = get_bad_protocol_flags(subtype, supertype) + for name, subflags, superflags in conflict_flags[:MAX_ITEMS]: + if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags: + self.note('Protocol member {}.{} expected instance variable,' + ' got class variable'.format(supertype.type.name(), name), context) + if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags: + self.note('Protocol member {}.{} expected class variable,' + ' got instance variable'.format(supertype.type.name(), name), context) + if IS_SETTABLE in superflags and IS_SETTABLE not in subflags: + self.note('Protocol member {}.{} expected settable variable,' + ' got read-only attribute'.format(supertype.type.name(), name), context) + if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: + self.note('Protocol member {}.{} expected class or static method' + .format(supertype.type.name(), name), context) + self.print_more(conflict_flags, context, OFFSET, MAX_ITEMS) + + def pretty_overload(self, tp: Overloaded, context: Context, + offset: int, max_items: int) -> None: + for item in tp.items()[:max_items]: + self.note('@overload', context, offset=2 * offset) + self.note(self.pretty_callable(item), context, offset=2 * offset) + if len(tp.items()) > max_items: + self.note('<{} more overload(s) not shown>'.format(len(tp.items()) - max_items), + context, offset=2 * offset) + + def print_more(self, conflicts: Sequence[Any], context: Context, + offset: int, max_items: int) -> None: + if len(conflicts) > max_items: + self.note('<{} more conflict(s) not shown>' + .format(len(conflicts) - max_items), + context, offset=offset) + + def pretty_callable(self, tp: CallableType) -> str: + """Return a nice easily-readable representation of a callable type. + For example: + def [T <: int] f(self, x: int, y: T) -> None + """ + s = '' + asterisk = False + for i in range(len(tp.arg_types)): + if s: + s += ', ' + if tp.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not asterisk: + s += '*, ' + asterisk = True + if tp.arg_kinds[i] == ARG_STAR: + s += '*' + asterisk = True + if tp.arg_kinds[i] == ARG_STAR2: + s += '**' + name = tp.arg_names[i] + if name: + s += name + ': ' + s += strip_quotes(self.format(tp.arg_types[i])) + if tp.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT): + s += ' = ...' + + # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list + if tp.definition is not None and tp.definition.name() is not None: + definition_args = getattr(tp.definition, 'arg_names') + if definition_args and tp.arg_names != definition_args \ + and len(definition_args) > 0: + if s: + s = ', ' + s + s = definition_args[0] + s + s = '{}({})'.format(tp.definition.name(), s) + else: + s = '({})'.format(s) + + s += ' -> ' + strip_quotes(self.format(tp.ret_type)) + if tp.variables: + tvars = [] + for tvar in tp.variables: + if (tvar.upper_bound and isinstance(tvar.upper_bound, Instance) and + tvar.upper_bound.type.fullname() != 'builtins.object'): + tvars.append('{} <: {}'.format(tvar.name, + strip_quotes(self.format(tvar.upper_bound)))) + elif tvar.values: + tvars.append('{} in ({})' + .format(tvar.name, ', '.join([strip_quotes(self.format(tp)) + for tp in tvar.values]))) + else: + tvars.append(tvar.name) + s = '[{}] {}'.format(', '.join(tvars), s) + return 'def {}'.format(s) + + +def variance_string(variance: int) -> str: + if variance == COVARIANT: + return 'covariant' + elif variance == CONTRAVARIANT: + return 'contravariant' + else: + return 'invariant' + + +def get_missing_protocol_members(left: Instance, right: Instance) -> List[str]: + """Find all protocol members of 'right' that are not implemented + (i.e. completely missing) in 'left'. + """ + from mypy.subtypes import find_member + assert right.type.is_protocol + missing = [] # type: List[str] + for member in right.type.protocol_members: + if not find_member(member, left, left): + missing.append(member) + return missing + + +def get_conflict_protocol_types(left: Instance, right: Instance) -> List[Tuple[str, Type, Type]]: + """Find members that are defined in 'left' but have incompatible types. + Return them as a list of ('member', 'got', 'expected'). + """ + from mypy.subtypes import find_member, is_subtype, get_member_flags, IS_SETTABLE + assert right.type.is_protocol + conflicts = [] # type: List[Tuple[str, Type, Type]] + for member in right.type.protocol_members: + if member in ('__init__', '__new__'): + continue + supertype = find_member(member, right, left) + assert supertype is not None + subtype = find_member(member, left, left) + if not subtype: + continue + is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) + if IS_SETTABLE in get_member_flags(member, right.type): + is_compat = is_compat and is_subtype(supertype, subtype) + if not is_compat: + conflicts.append((member, subtype, supertype)) + return conflicts + + +def get_bad_protocol_flags(left: Instance, right: Instance + ) -> List[Tuple[str, Set[int], Set[int]]]: + """Return all incompatible attribute flags for members that are present in both + 'left' and 'right'. + """ + from mypy.subtypes import (find_member, get_member_flags, + IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC) + assert right.type.is_protocol + all_flags = [] # type: List[Tuple[str, Set[int], Set[int]]] + for member in right.type.protocol_members: + if find_member(member, left, left): + item = (member, + get_member_flags(member, left.type), + get_member_flags(member, right.type)) + all_flags.append(item) + bad_flags = [] + for name, subflags, superflags in all_flags: + if (IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags or + IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags or + IS_SETTABLE in superflags and IS_SETTABLE not in subflags or + IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags): + bad_flags.append((name, subflags, superflags)) + return bad_flags + def capitalize(s: str) -> str: """Capitalize the first character of a string.""" diff --git a/mypy/nodes.py b/mypy/nodes.py index 16d57c418502..400339c083c1 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -669,6 +669,7 @@ class Var(SymbolNode): is_property = False is_settable_property = False is_classvar = False + is_abstract_var = False # Set to true when this variable refers to a module we were unable to # parse for some reason (eg a silenced module) is_suppressed_import = False @@ -676,7 +677,7 @@ class Var(SymbolNode): FLAGS = [ 'is_self', 'is_ready', 'is_initialized_in_class', 'is_staticmethod', 'is_classmethod', 'is_property', 'is_settable_property', 'is_suppressed_import', - 'is_classvar' + 'is_classvar', 'is_abstract_var' ] def __init__(self, name: str, type: 'Optional[mypy.types.Type]' = None) -> None: @@ -1932,9 +1933,13 @@ class TempNode(Expression): """ type = None # type: mypy.types.Type + # Is this TempNode used to indicate absence of a right hand side in an annotated assignment? + # (e.g. for 'x: int' the rvalue is TempNode(AnyType(TypeOfAny.special_form), no_rhs=True)) + no_rhs = False # type: bool - def __init__(self, typ: 'mypy.types.Type') -> None: + def __init__(self, typ: 'mypy.types.Type', no_rhs: bool = False) -> None: self.type = typ + self.no_rhs = no_rhs def __repr__(self) -> str: return 'TempNode(%s)' % str(self.type) @@ -1972,7 +1977,53 @@ class is generic then it will be a type constructor of higher kind. subtypes = None # type: Set[TypeInfo] # Direct subclasses encountered so far names = None # type: SymbolTable # Names defined directly in this type is_abstract = False # Does the class have any abstract attributes? + is_protocol = False # Is this a protocol class? + runtime_protocol = False # Does this protocol support isinstance checks? abstract_attributes = None # type: List[str] + # Protocol members are names of all attributes/methods defined in a protocol + # and in all its supertypes (except for 'object'). + protocol_members = None # type: List[str] + + # The attributes 'assuming' and 'assuming_proper' represent structural subtype matrices. + # + # In languages with structural subtyping, one can keep a global subtype matrix like this: + # . A B C . + # A 1 0 0 + # B 1 1 1 + # C 1 0 1 + # . + # where 1 indicates that the type in corresponding row is a subtype of the type + # in corresponding column. This matrix typically starts filled with all 1's and + # a typechecker tries to "disprove" every subtyping relation using atomic (or nominal) types. + # However, we don't want to keep this huge global state. Instead, we keep the subtype + # information in the form of list of pairs (subtype, supertype) shared by all 'Instance's + # with given supertype's TypeInfo. When we enter a subtype check we push a pair in this list + # thus assuming that we started with 1 in corresponding matrix element. Such algorithm allows + # to treat recursive and mutually recursive protocols and other kinds of complex situations. + # + # If concurrent/parallel type checking will be added in future, + # then there should be one matrix per thread/process to avoid false negatives + # during the type checking phase. + assuming = None # type: List[Tuple[mypy.types.Instance, mypy.types.Instance]] + assuming_proper = None # type: List[Tuple[mypy.types.Instance, mypy.types.Instance]] + # Ditto for temporary 'inferring' stack of recursive constraint inference. + # It contains Instance's of protocol types that appeared as an argument to + # constraints.infer_constraints(). We need 'inferring' to avoid infinite recursion for + # recursive and mutually recursive protocols. + # + # We make 'assuming' and 'inferring' attributes here instead of passing they as kwargs, + # since this would require to pass them in many dozens of calls. In particular, + # there is a dependency infer_constraint -> is_subtype -> is_callable_subtype -> + # -> infer_constraints. + inferring = None # type: List[mypy.types.Instance] + # '_cache' and '_cache_proper' are subtype caches, implemented as sets of pairs + # of (subtype, supertype), where supertypes are instances of given TypeInfo. + # We need the caches, since subtype checks for structural types are very slow. + _cache = None # type: Set[Tuple[mypy.types.Type, mypy.types.Type]] + _cache_proper = None # type: Set[Tuple[mypy.types.Type, mypy.types.Type]] + # 'inferring' and 'assuming' can't be also made sets, since we need to use + # is_same_type to correctly treat unions. + # Classes inheriting from Enum shadow their true members with a __getattr__, so we # have to treat them as a special case. is_enum = False @@ -2016,7 +2067,7 @@ class is generic then it will be a type constructor of higher kind. FLAGS = [ 'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple', - 'is_newtype' + 'is_newtype', 'is_protocol', 'runtime_protocol' ] def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None: @@ -2032,6 +2083,11 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No self._fullname = defn.fullname self.is_abstract = False self.abstract_attributes = [] + self.assuming = [] + self.assuming_proper = [] + self.inferring = [] + self._cache = set() + self._cache_proper = set() self.add_type_vars() def add_type_vars(self) -> None: @@ -2051,6 +2107,9 @@ def is_generic(self) -> bool: return len(self.type_vars) > 0 def get(self, name: str) -> 'Optional[SymbolTableNode]': + if self.mro is None: # Might be because of a previous error. + return None + for cls in self.mro: n = cls.names.get(name) if n: @@ -2063,6 +2122,21 @@ def get_containing_type_info(self, name: str) -> 'Optional[TypeInfo]': return cls return None + def record_subtype_cache_entry(self, left: 'mypy.types.Instance', + right: 'mypy.types.Instance', + proper_subtype: bool = False) -> None: + if proper_subtype: + self._cache_proper.add((left, right)) + else: + self._cache.add((left, right)) + + def is_cached_subtype_check(self, left: 'mypy.types.Instance', + right: 'mypy.types.Instance', + proper_subtype: bool = False) -> bool: + if not proper_subtype: + return (left, right) in self._cache + return (left, right) in self._cache_proper + def __getitem__(self, name: str) -> 'SymbolTableNode': n = self.get(name) if n: @@ -2203,6 +2277,7 @@ def serialize(self) -> JsonDict: 'names': self.names.serialize(self.fullname()), 'defn': self.defn.serialize(), 'abstract_attributes': self.abstract_attributes, + 'protocol_members': self.protocol_members, 'type_vars': self.type_vars, 'bases': [b.serialize() for b in self.bases], '_promote': None if self._promote is None else self._promote.serialize(), @@ -2226,6 +2301,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeInfo': ti._fullname = data['fullname'] # TODO: Is there a reason to reconstruct ti.subtypes? ti.abstract_attributes = data['abstract_attributes'] + ti.protocol_members = data['protocol_members'] ti.type_vars = data['type_vars'] ti.bases = [mypy.types.Instance.deserialize(b) for b in data['bases']] ti._promote = (None if data['_promote'] is None diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 0531ecc2d474..cba80e1ef825 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -68,8 +68,8 @@ def visit_erased_type(self, left: ErasedType) -> bool: # We can get here when isinstance is used inside a lambda # whose type is being inferred. In any event, we have no reason # to think that an ErasedType will end up being the same as - # any other type, even another ErasedType. - return False + # any other type, except another ErasedType (for protocols). + return isinstance(self.right, ErasedType) def visit_deleted_type(self, left: DeletedType) -> bool: return isinstance(self.right, DeletedType) diff --git a/mypy/semanal.py b/mypy/semanal.py index 6f1ee503b19c..e413c7fd5882 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -543,11 +543,17 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: if defn.impl is not None: assert defn.impl is defn.items[-1] defn.items = defn.items[:-1] - elif not self.is_stub_file and not non_overload_indexes: - self.fail( - "An overloaded function outside a stub file must have an implementation", - defn) + if not (self.is_class_scope() and self.type.is_protocol): + self.fail( + "An overloaded function outside a stub file must have an implementation", + defn) + else: + for item in defn.items: + if isinstance(item, Decorator): + item.func.is_abstract = True + else: + item.is_abstract = True if types: defn.type = Overloaded(types) @@ -671,6 +677,7 @@ def visit_class_def(self, defn: ClassDef) -> None: @contextmanager def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: with self.tvar_scope_frame(self.tvar_scope.class_frame()): + is_protocol = self.detect_protocol_base(defn) self.clean_up_bases_and_infer_type_variables(defn) self.analyze_class_keywords(defn) if self.analyze_typeddict_classdef(defn): @@ -710,13 +717,12 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: self.setup_class_def_analysis(defn) self.analyze_base_classes(defn) self.analyze_metaclass(defn) - + defn.info.is_protocol = is_protocol + defn.info.runtime_protocol = False for decorator in defn.decorators: self.analyze_class_decorator(defn, decorator) - self.enter_class(defn.info) yield True - self.calculate_abstract_status(defn.info) self.setup_type_promotion(defn) @@ -743,6 +749,12 @@ def leave_class(self) -> None: def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: decorator.accept(self) + if (isinstance(decorator, RefExpr) and + decorator.fullname in ('typing.runtime', 'typing_extensions.runtime')): + if defn.info.is_protocol: + defn.info.runtime_protocol = True + else: + self.fail('@runtime can only be used with protocol classes', defn) def calculate_abstract_status(self, typ: TypeInfo) -> None: """Calculate abstract status of a class. @@ -768,6 +780,10 @@ def calculate_abstract_status(self, typ: TypeInfo) -> None: if fdef.is_abstract and name not in concrete: typ.is_abstract = True abstract.append(name) + elif isinstance(node, Var): + if node.is_abstract_var and name not in concrete: + typ.is_abstract = True + abstract.append(name) concrete.add(name) typ.abstract_attributes = sorted(abstract) @@ -790,6 +806,21 @@ def setup_type_promotion(self, defn: ClassDef) -> None: promote_target = self.named_type_or_none(promotions[defn.fullname]) defn.info._promote = promote_target + def detect_protocol_base(self, defn: ClassDef) -> bool: + for base_expr in defn.base_type_exprs: + try: + base = expr_to_unanalyzed_type(base_expr) + except TypeTranslationError: + continue # This will be reported later + if not isinstance(base, UnboundType): + continue + sym = self.lookup_qualified(base.name, base) + if sym is None or sym.node is None: + continue + if sym.node.fullname() in ('typing.Protocol', 'typing_extensions.Protocol'): + return True + return False + def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: """Remove extra base classes such as Generic and infer type vars. @@ -817,17 +848,26 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: tvars = self.analyze_typevar_declaration(base) if tvars is not None: if declared_tvars: - self.fail('Duplicate Generic in bases', defn) + self.fail('Only single Generic[...] or Protocol[...] can be in bases', defn) removed.append(i) declared_tvars.extend(tvars) + if isinstance(base, UnboundType): + sym = self.lookup_qualified(base.name, base) + if sym is not None and sym.node is not None: + if (sym.node.fullname() in ('typing.Protocol', + 'typing_extensions.Protocol') and + i not in removed): + # also remove bare 'Protocol' bases + removed.append(i) all_tvars = self.get_all_bases_tvars(defn, removed) if declared_tvars: if len(remove_dups(declared_tvars)) < len(declared_tvars): - self.fail("Duplicate type variables in Generic[...]", defn) + self.fail("Duplicate type variables in Generic[...] or Protocol[...]", defn) declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): - self.fail("If Generic[...] is present it should list all type variables", defn) + self.fail("If Generic[...] or Protocol[...] is present" + " it should list all type variables", defn) # In case of error, Generic tvars will go first declared_tvars = remove_dups(declared_tvars + all_tvars) else: @@ -849,7 +889,9 @@ def analyze_typevar_declaration(self, t: Type) -> Optional[TypeVarList]: sym = self.lookup_qualified(unbound.name, unbound) if sym is None or sym.node is None: return None - if sym.node.fullname() == 'typing.Generic': + if (sym.node.fullname() == 'typing.Generic' or + sym.node.fullname() == 'typing.Protocol' and t.args or + sym.node.fullname() == 'typing_extensions.Protocol' and t.args): tvars = [] # type: TypeVarList for arg in unbound.args: tvar = self.analyze_unbound_tvar(arg) @@ -1594,10 +1636,16 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: if s.type: allow_tuple_literal = isinstance(s.lvalues[-1], (TupleExpr, ListExpr)) s.type = self.anal_type(s.type, allow_tuple_literal=allow_tuple_literal) + if (self.type and self.type.is_protocol and isinstance(lval, NameExpr) and + isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs): + if isinstance(lval.node, Var): + lval.node.is_abstract_var = True else: - # Set the type if the rvalue is a simple literal. - if (s.type is None and len(s.lvalues) == 1 and - isinstance(s.lvalues[0], NameExpr)): + if (any(isinstance(lv, NameExpr) and lv.is_def for lv in s.lvalues) and + self.type and self.type.is_protocol and not self.is_func_scope()): + self.fail('All protocol members must have explicitly declared types', s) + # Set the type if the rvalue is a simple literal (even if the above error occurred). + if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr): if s.lvalues[0].is_def: s.type = self.analyze_simple_literal_type(s.rvalue) if s.type: @@ -1832,18 +1880,22 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr], def analyze_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) - if (self.is_self_member_ref(lval) and - self.type.get(lval.name) is None): - # Implicit attribute definition in __init__. - lval.is_def = True - v = Var(lval.name) - v.set_line(lval) - v._fullname = self.qualified_name(lval.name) - v.info = self.type - v.is_ready = False - lval.def_var = v - lval.node = v - self.type.names[lval.name] = SymbolTableNode(MDEF, v, implicit=True) + if self.is_self_member_ref(lval): + node = self.type.get(lval.name) + if node is None or isinstance(node.node, Var) and node.node.is_abstract_var: + if self.type.is_protocol and node is None: + self.fail("Protocol members cannot be defined via assignment to self", lval) + else: + # Implicit attribute definition in __init__. + lval.is_def = True + v = Var(lval.name) + v.set_line(lval) + v._fullname = self.qualified_name(lval.name) + v.info = self.type + v.is_ready = False + lval.def_var = v + lval.node = v + self.type.names[lval.name] = SymbolTableNode(MDEF, v, implicit=True) self.check_lvalue_validity(lval.node, lval) def is_self_member_ref(self, memberexpr: MemberExpr) -> bool: @@ -1906,6 +1958,8 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> None: newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) newtype_class_info.tuple_type = old_type elif isinstance(old_type, Instance): + if old_type.type.is_protocol: + self.fail("NewType cannot be used with protocol classes", s) newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) else: message = "Argument 2 to NewType(...) must be subclassable (got {})" @@ -3778,6 +3832,11 @@ def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) - ('False', bool_type), ('__debug__', bool_type), ]) + else: + # We are running tests without 'bool' in builtins. + # TODO: Find a permanent solution to this problem. + # Maybe add 'bool' to all fixtures? + literal_types.append(('True', AnyType(TypeOfAny.special_form))) for name, typ in literal_types: v = Var(name, typ) @@ -4018,12 +4077,18 @@ def visit_class_def(self, tdef: ClassDef) -> None: if not tdef.info.is_named_tuple: for type in tdef.info.bases: self.analyze(type) + if tdef.info.is_protocol: + if not isinstance(type, Instance) or not type.type.is_protocol: + if type.type.fullname() != 'builtins.object': + self.fail('All bases of a protocol must be protocols', tdef) # Recompute MRO now that we have analyzed all modules, to pick # up superclasses of bases imported from other modules in an # import loop. (Only do so if we succeeded the first time.) if tdef.info.mro: tdef.info.mro = [] # Force recomputation calculate_class_mro(tdef, self.fail_blocker) + if tdef.info.is_protocol: + add_protocol_members(tdef.info) if tdef.analyzed is not None: if isinstance(tdef.analyzed, TypedDictExpr): self.analyze(tdef.analyzed.info.typeddict_type) @@ -4141,6 +4206,16 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance: return Instance(node, [any_type] * len(node.defn.type_vars)) +def add_protocol_members(typ: TypeInfo) -> None: + members = set() # type: Set[str] + if typ.mro: + for base in typ.mro[:-1]: # we skip "object" since everyone implements it + if base.is_protocol: + for name in base.names: + members.add(name) + typ.protocol_members = sorted(list(members)) + + def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 97701598921c..5dfbc7c405bc 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,9 +1,11 @@ -from typing import List, Optional, Dict, Callable, cast +from typing import List, Optional, Dict, Callable, Tuple, Iterator, Set, Union, cast +from contextlib import contextmanager from mypy.types import ( - Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneTyp, Instance, TypeVarType, - CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, TypeList, - PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, TypeOfAny + Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneTyp, function_type, + Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, + ErasedType, TypeList, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, + FunctionLike, TypeOfAny ) import mypy.applytype import mypy.constraints @@ -12,15 +14,22 @@ # import mypy.solve from mypy import messages, sametypes from mypy.nodes import ( - CONTRAVARIANT, COVARIANT, - ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, + FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT, + ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2 ) from mypy.maptype import map_instance_to_supertype +from mypy.expandtype import expand_type_by_instance from mypy.sametypes import is_same_type from mypy import experiments +# Flags for detected protocol members +IS_SETTABLE = 1 +IS_CLASSVAR = 2 +IS_CLASS_OR_STATIC = 3 + + TypeParameterChecker = Callable[[Type, Type, int], bool] @@ -35,7 +44,8 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: def is_subtype(left: Type, right: Type, type_parameter_checker: TypeParameterChecker = check_type_parameter, - *, ignore_pos_arg_names: bool = False) -> bool: + *, ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -69,7 +79,8 @@ def is_subtype(left: Type, right: Type, return True # otherwise, fall through return left.accept(SubtypeVisitor(right, type_parameter_checker, - ignore_pos_arg_names=ignore_pos_arg_names)) + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance)) def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool: @@ -93,10 +104,12 @@ class SubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, type_parameter_checker: TypeParameterChecker, - *, ignore_pos_arg_names: bool = False) -> None: + *, ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False) -> None: self.right = right self.check_type_parameter = type_parameter_checker self.ignore_pos_arg_names = ignore_pos_arg_names + self.ignore_declared_variance = ignore_declared_variance # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -130,24 +143,34 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(right, TupleType) and right.fallback.type.is_enum: return is_subtype(left, right.fallback) if isinstance(right, Instance): - # NOTO: left.type.mro may be None in quick mode if there + if right.type.is_cached_subtype_check(left, right): + return True + # NOTE: left.type.mro may be None in quick mode if there # was an error somewhere. if left.type.mro is not None: for base in left.type.mro: + # TODO: Also pass recursively ignore_declared_variance if base._promote and is_subtype( base._promote, self.right, self.check_type_parameter, ignore_pos_arg_names=self.ignore_pos_arg_names): + right.type.record_subtype_cache_entry(left, right) return True rname = right.type.fullname() - if not left.type.has_base(rname) and rname != 'builtins.object': - return False - - # Map left type to corresponding right instances. - t = map_instance_to_supertype(left, right.type) - - return all(self.check_type_parameter(lefta, righta, tvar.variance) - for lefta, righta, tvar in - zip(t.args, right.args, right.type.defn.type_vars)) + # Always try a nominal check if possible, + # there might be errors that a user wants to silence *once*. + if ((left.type.has_base(rname) or rname == 'builtins.object') and + not self.ignore_declared_variance): + # Map left type to corresponding right instances. + t = map_instance_to_supertype(left, right.type) + nominal = all(self.check_type_parameter(lefta, righta, tvar.variance) + for lefta, righta, tvar in + zip(t.args, right.args, right.type.defn.type_vars)) + if nominal: + right.type.record_subtype_cache_entry(left, right) + return nominal + if right.type.is_protocol and is_protocol_implementation(left, right): + return True + return False if isinstance(right, TypeType): item = right.item if isinstance(item, TupleType): @@ -163,7 +186,14 @@ def visit_instance(self, left: Instance) -> bool: and is_named_instance(item, 'enum.Enum')): return True return is_named_instance(item, 'builtins.object') - return False + if isinstance(right, CallableType): + # Special case: Instance can be a subtype of Callable. + call = find_member('__call__', left, left) + if call: + return is_subtype(call, right) + return False + else: + return False def visit_type_var(self, left: TypeVarType) -> bool: right = self.right @@ -302,6 +332,190 @@ def visit_type_type(self, left: TypeType) -> bool: return False +@contextmanager +def pop_on_exit(stack: List[Tuple[Instance, Instance]], + left: Instance, right: Instance) -> Iterator[None]: + stack.append((left, right)) + yield + stack.pop() + + +def is_protocol_implementation(left: Instance, right: Instance, + proper_subtype: bool = False) -> bool: + """Check whether 'left' implements the protocol 'right'. + + If 'proper_subtype' is True, then check for a proper subtype. + Treat recursive protocols by using the 'assuming' structural subtype matrix + (in sparse representation, i.e. as a list of pairs (subtype, supertype)), + see also comment in nodes.TypeInfo. When we enter a check for classes + (A, P), defined as following:: + + class P(Protocol): + def f(self) -> P: ... + class A: + def f(self) -> A: ... + + this results in A being a subtype of P without infinite recursion. + On every false result, we pop the assumption, thus avoiding an infinite recursion + as well. + """ + assert right.type.is_protocol + assuming = right.type.assuming_proper if proper_subtype else right.type.assuming + for (l, r) in reversed(assuming): + if sametypes.is_same_type(l, left) and sametypes.is_same_type(r, right): + return True + with pop_on_exit(assuming, left, right): + for member in right.type.protocol_members: + # nominal subtyping currently ignores '__init__' and '__new__' signatures + if member in ('__init__', '__new__'): + continue + # The third argument below indicates to what self type is bound. + # We always bind self to the subtype. (Similarly to nominal types). + supertype = find_member(member, right, left) + assert supertype is not None + subtype = find_member(member, left, left) + # Useful for debugging: + # print(member, 'of', left, 'has type', subtype) + # print(member, 'of', right, 'has type', supertype) + if not subtype: + return False + if not proper_subtype: + # Nominal check currently ignores arg names + is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) + else: + is_compat = is_proper_subtype(subtype, supertype) + if not is_compat: + return False + if isinstance(subtype, NoneTyp) and member.startswith('__') and member.endswith('__'): + # We want __hash__ = None idiom to work even without --strict-optional + return False + subflags = get_member_flags(member, left.type) + superflags = get_member_flags(member, right.type) + if IS_SETTABLE in superflags: + # Check opposite direction for settable attributes. + if not is_subtype(supertype, subtype): + return False + if (IS_CLASSVAR in subflags) != (IS_CLASSVAR in superflags): + return False + if IS_SETTABLE in superflags and IS_SETTABLE not in subflags: + return False + # This rule is copied from nominal check in checker.py + if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: + return False + right.type.record_subtype_cache_entry(left, right, proper_subtype) + return True + + +def find_member(name: str, itype: Instance, subtype: Type) -> Optional[Type]: + """Find the type of member by 'name' in 'itype's TypeInfo. + + Fin the member type after applying type arguments from 'itype', and binding + 'self' to 'subtype'. Return None if member was not found. + """ + # TODO: this code shares some logic with checkmember.analyze_member_access, + # consider refactoring. + info = itype.type + method = info.get_method(name) + if method: + if method.is_property: + assert isinstance(method, OverloadedFuncDef) + dec = method.items[0] + assert isinstance(dec, Decorator) + return find_node_type(dec.var, itype, subtype) + return find_node_type(method, itype, subtype) + else: + # don't have such method, maybe variable or decorator? + node = info.get(name) + if not node: + v = None + else: + v = node.node + if isinstance(v, Decorator): + v = v.var + if isinstance(v, Var): + return find_node_type(v, itype, subtype) + if not v and name not in ['__getattr__', '__setattr__', '__getattribute__']: + for method_name in ('__getattribute__', '__getattr__'): + # Normally, mypy assumes that instances that define __getattr__ have all + # attributes with the corresponding return type. If this will produce + # many false negatives, then this could be prohibited for + # structural subtyping. + method = info.get_method(method_name) + if method and method.info.fullname() != 'builtins.object': + getattr_type = find_node_type(method, itype, subtype) + if isinstance(getattr_type, CallableType): + return getattr_type.ret_type + if itype.type.fallback_to_any: + return AnyType(TypeOfAny.special_form) + return None + + +def get_member_flags(name: str, info: TypeInfo) -> Set[int]: + """Detect whether a member 'name' is settable, whether it is an + instance or class variable, and whether it is class or static method. + + The flags are defined as following: + * IS_SETTABLE: whether this attribute can be set, not set for methods and + non-settable properties; + * IS_CLASSVAR: set if the variable is annotated as 'x: ClassVar[t]'; + * IS_CLASS_OR_STATIC: set for methods decorated with @classmethod or + with @staticmethod. + """ + method = info.get_method(name) + setattr_meth = info.get_method('__setattr__') + if method: + # this could be settable property + if method.is_property: + assert isinstance(method, OverloadedFuncDef) + dec = method.items[0] + assert isinstance(dec, Decorator) + if dec.var.is_settable_property or setattr_meth: + return {IS_SETTABLE} + return set() + node = info.get(name) + if not node: + if setattr_meth: + return {IS_SETTABLE} + return set() + v = node.node + if isinstance(v, Decorator): + if v.var.is_staticmethod or v.var.is_classmethod: + return {IS_CLASS_OR_STATIC} + # just a variable + if isinstance(v, Var): + flags = {IS_SETTABLE} + if v.is_classvar: + flags.add(IS_CLASSVAR) + return flags + return set() + + +def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) -> Type: + """Find type of a variable or method 'node' (maybe also a decorated method). + Apply type arguments from 'itype', and bind 'self' to 'subtype'. + """ + from mypy.checkmember import bind_self + if isinstance(node, FuncBase): + typ = function_type(node, + fallback=Instance(itype.type.mro[-1], [])) # type: Optional[Type] + else: + typ = node.type + if typ is None: + return AnyType(TypeOfAny.from_error) + # We don't need to bind 'self' for static methods, since there is no 'self'. + if isinstance(node, FuncBase) or isinstance(typ, FunctionLike) and not node.is_staticmethod: + assert isinstance(typ, FunctionLike) + signature = bind_self(typ, subtype) + if node.is_property: + assert isinstance(signature, CallableType) + typ = signature.ret_type + else: + typ = signature + itype = map_instance_to_supertype(itype, node.info) + typ = expand_type_by_instance(typ, itype) + return typ + + def is_callable_subtype(left: CallableType, right: CallableType, ignore_return: bool = False, ignore_pos_arg_names: bool = False, @@ -548,8 +762,11 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: if isinstance(t, UnionType): # Since runtime type checks will ignore type arguments, erase the types. erased_s = erase_type(s) + # TODO: Implement more robust support for runtime isinstance() checks, + # see issue #3827 new_items = [item for item in t.relevant_items() - if (not is_proper_subtype(erase_type(item), erased_s) + if (not (is_proper_subtype(erase_type(item), erased_s) or + is_proper_subtype(item, erased_s)) or isinstance(item, AnyType))] return UnionType.make_union(new_items) else: @@ -601,26 +818,38 @@ def visit_deleted_type(self, left: DeletedType) -> bool: def visit_instance(self, left: Instance) -> bool: right = self.right if isinstance(right, Instance): + if right.type.is_cached_subtype_check(left, right, proper_subtype=True): + return True for base in left.type.mro: if base._promote and is_proper_subtype(base._promote, right): + right.type.record_subtype_cache_entry(left, right, proper_subtype=True) return True - if not left.type.has_base(right.type.fullname()): - return False - - def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: - if variance == COVARIANT: - return is_proper_subtype(leftarg, rightarg) - elif variance == CONTRAVARIANT: - return is_proper_subtype(rightarg, leftarg) - else: - return sametypes.is_same_type(leftarg, rightarg) - - # Map left type to corresponding right instances. - left = map_instance_to_supertype(left, right.type) - - return all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in - zip(left.args, right.args, right.type.defn.type_vars)) + if left.type.has_base(right.type.fullname()): + def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: + if variance == COVARIANT: + return is_proper_subtype(leftarg, rightarg) + elif variance == CONTRAVARIANT: + return is_proper_subtype(rightarg, leftarg) + else: + return sametypes.is_same_type(leftarg, rightarg) + # Map left type to corresponding right instances. + left = map_instance_to_supertype(left, right.type) + + nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in + zip(left.args, right.args, right.type.defn.type_vars)) + if nominal: + right.type.record_subtype_cache_entry(left, right, proper_subtype=True) + return nominal + if (right.type.is_protocol and + is_protocol_implementation(left, right, proper_subtype=True)): + return True + return False + if isinstance(right, CallableType): + call = find_member('__call__', left, left) + if call: + return is_proper_subtype(call, right) + return False return False def visit_type_var(self, left: TypeVarType) -> bool: diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 4ac59bd94f57..6685ed1011af 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -72,6 +72,7 @@ 'check-generic-subtyping.test', 'check-varargs.test', 'check-newsyntax.test', + 'check-protocols.test', 'check-underscores.test', 'check-classvar.test', 'check-enum.test', diff --git a/mypy/types.py b/mypy/types.py index a94b6e2b7152..192bf08eb35d 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -188,6 +188,15 @@ def __init__(self, def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_unbound_type(self) + def __hash__(self) -> int: + return hash((self.name, self.optional, tuple(self.args))) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UnboundType): + return NotImplemented + return (self.name == other.name and self.optional == other.optional and + self.args == other.args) + def serialize(self) -> JsonDict: return {'.class': 'UnboundType', 'name': self.name, @@ -316,6 +325,12 @@ def copy_modified(self, return AnyType(type_of_any=type_of_any, source_any=original_any, line=self.line, column=self.column) + def __hash__(self) -> int: + return hash(AnyType) + + def __eq__(self, other: object) -> bool: + return isinstance(other, AnyType) + def serialize(self) -> JsonDict: return {'.class': 'AnyType'} @@ -350,6 +365,12 @@ def __init__(self, is_noreturn: bool = False, line: int = -1, column: int = -1) def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_uninhabited_type(self) + def __hash__(self) -> int: + return hash(UninhabitedType) + + def __eq__(self, other: object) -> bool: + return isinstance(other, UninhabitedType) + def serialize(self) -> JsonDict: return {'.class': 'UninhabitedType', 'is_noreturn': self.is_noreturn} @@ -371,6 +392,12 @@ class NoneTyp(Type): def __init__(self, line: int = -1, column: int = -1) -> None: super().__init__(line, column) + def __hash__(self) -> int: + return hash(NoneTyp) + + def __eq__(self, other: object) -> bool: + return isinstance(other, NoneTyp) + def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_none_type(self) @@ -450,6 +477,14 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T: type_ref = None # type: str + def __hash__(self) -> int: + return hash((self.type, tuple(self.args))) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Instance): + return NotImplemented + return self.type == other.type and self.args == other.args + def serialize(self) -> Union[JsonDict, str]: assert self.type is not None type_ref = self.type.fullname() @@ -512,6 +547,14 @@ def erase_to_union_or_bound(self) -> Type: else: return self.upper_bound + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypeVarType): + return NotImplemented + return self.id == other.id + def serialize(self) -> JsonDict: assert not self.id.is_meta_var() return {'.class': 'TypeVarType', @@ -790,6 +833,23 @@ def type_var_ids(self) -> List[TypeVarId]: a.append(tv.id) return a + def __hash__(self) -> int: + return hash((self.ret_type, self.is_type_obj(), + self.is_ellipsis_args, self.name, + tuple(self.arg_types), tuple(self.arg_names), tuple(self.arg_kinds))) + + def __eq__(self, other: object) -> bool: + if isinstance(other, CallableType): + return (self.ret_type == other.ret_type and + self.arg_types == other.arg_types and + self.arg_names == other.arg_names and + self.arg_kinds == other.arg_kinds and + self.name == other.name and + self.is_type_obj() == other.is_type_obj() and + self.is_ellipsis_args == other.is_ellipsis_args) + else: + return NotImplemented + def serialize(self) -> JsonDict: # TODO: As an optimization, leave out everything related to # generic functions for non-generic functions. @@ -872,6 +932,14 @@ def get_name(self) -> Optional[str]: def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_overloaded(self) + def __hash__(self) -> int: + return hash(tuple(self.items())) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Overloaded): + return NotImplemented + return self.items() == other.items() + def serialize(self) -> JsonDict: return {'.class': 'Overloaded', 'items': [t.serialize() for t in self.items()], @@ -913,6 +981,14 @@ def length(self) -> int: def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_tuple_type(self) + def __hash__(self) -> int: + return hash((tuple(self.items), self.fallback)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TupleType): + return NotImplemented + return self.items == other.items and self.fallback == other.fallback + def serialize(self) -> JsonDict: return {'.class': 'TupleType', 'items': [t.serialize() for t in self.items], @@ -965,6 +1041,21 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str], def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_typeddict_type(self) + def __hash__(self) -> int: + return hash((frozenset(self.items.items()), self.fallback, + frozenset(self.required_keys))) + + def __eq__(self, other: object) -> bool: + if isinstance(other, TypedDictType): + if frozenset(self.items.keys()) != frozenset(other.items.keys()): + return False + for (_, left_item_type, right_item_type) in self.zip(other): + if not left_item_type == right_item_type: + return False + return self.fallback == other.fallback and self.required_keys == other.required_keys + else: + return NotImplemented + def serialize(self) -> JsonDict: return {'.class': 'TypedDictType', 'items': [[n, t.serialize()] for (n, t) in self.items.items()], @@ -1062,6 +1153,14 @@ def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None: self.can_be_false = any(item.can_be_false for item in items) super().__init__(line, column) + def __hash__(self) -> int: + return hash(frozenset(self.items)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UnionType): + return NotImplemented + return frozenset(self.items) == frozenset(other.items) + @staticmethod def make_union(items: List[Type], line: int = -1, column: int = -1) -> Type: if len(items) > 1: @@ -1258,6 +1357,14 @@ def make_normalized(item: Type, *, line: int = -1, column: int = -1) -> Type: def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_type_type(self) + def __hash__(self) -> int: + return hash(self.item) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypeType): + return NotImplemented + return self.item == other.item + def serialize(self) -> JsonDict: return {'.class': 'TypeType', 'item': self.item.serialize()} diff --git a/test-data/unit/check-abstract.test b/test-data/unit/check-abstract.test index 9a0d4afc5971..cd8cbe43b595 100644 --- a/test-data/unit/check-abstract.test +++ b/test-data/unit/check-abstract.test @@ -174,8 +174,8 @@ def f(cls: Type[A]) -> A: def g() -> A: return A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'm' -f(A) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected -f(B) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected +f(A) # E: Only concrete class can be given where 'Type[A]' is expected +f(B) # E: Only concrete class can be given where 'Type[A]' is expected f(C) # OK x: Type[B] f(x) # OK @@ -200,7 +200,7 @@ Alias = A GoodAlias = C Alias() # E: Cannot instantiate abstract class 'A' with abstract attribute 'm' GoodAlias() -f(Alias) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected +f(Alias) # E: Only concrete class can be given where 'Type[A]' is expected f(GoodAlias) [out] @@ -218,14 +218,14 @@ class C(B): var: Type[A] var() -var = A # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]' -var = B # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]' +var = A # E: Can only assign concrete classes to a variable of type 'Type[A]' +var = B # E: Can only assign concrete classes to a variable of type 'Type[A]' var = C # OK var_old = None # type: Type[A] # Old syntax for variable annotations var_old() -var_old = A # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]' -var_old = B # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]' +var_old = A # E: Can only assign concrete classes to a variable of type 'Type[A]' +var_old = B # E: Can only assign concrete classes to a variable of type 'Type[A]' var_old = C # OK [out] diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 7d7f0fa2ddd6..87a25a8e7858 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -3152,19 +3152,20 @@ def f(TB: Type[B]): [case testMetaclassIterable] from typing import Iterable, Iterator -class BadMeta(type): +class ImplicitMeta(type): def __iter__(self) -> Iterator[int]: yield 1 -class Bad(metaclass=BadMeta): pass +class Implicit(metaclass=ImplicitMeta): pass -for _ in Bad: pass # E: Iterable expected +for _ in Implicit: pass +reveal_type(list(Implicit)) # E: Revealed type is 'builtins.list[builtins.int*]' -class GoodMeta(type, Iterable[int]): +class ExplicitMeta(type, Iterable[int]): def __iter__(self) -> Iterator[int]: yield 1 -class Good(metaclass=GoodMeta): pass -for _ in Good: pass -reveal_type(list(Good)) # E: Revealed type is 'builtins.list[builtins.int*]' +class Explicit(metaclass=ExplicitMeta): pass +for _ in Explicit: pass +reveal_type(list(Explicit)) # E: Revealed type is 'builtins.list[builtins.int*]' [builtins fixtures/list.pyi] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 0d962533eabd..4ad8620d53d4 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -1201,7 +1201,7 @@ reveal_type(D[str, int]().c()) # E: Revealed type is 'builtins.str*' from typing import TypeVar, Generic T = TypeVar('T') -class A(Generic[T, T]): # E: Duplicate type variables in Generic[...] +class A(Generic[T, T]): # E: Duplicate type variables in Generic[...] or Protocol[...] pass a = A[int]() @@ -1218,7 +1218,7 @@ class A(Generic[T]): class B(Generic[T]): pass -class C(A[T], B[S], Generic[T]): # E: If Generic[...] is present it should list all type variables +class C(A[T], B[S], Generic[T]): # E: If Generic[...] or Protocol[...] is present it should list all type variables pass c = C[int, str]() diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 557d42a1ec42..916a17752c7d 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -1501,6 +1501,99 @@ class MyClass: [rechecked] [stale] +[case testIncrementalWorksWithBasicProtocols] +import a +[file a.py] +from b import P + +x: int +y: P[int] +x = y.meth() + +class C: + def meth(self) -> int: + pass +y = C() + +[file a.py.2] +from b import P + +x: str +y: P[str] +x = y.meth() + +class C: + def meth(self) -> str: + pass +y = C() +[file b.py] +from typing import Protocol, TypeVar + +T = TypeVar('T', covariant=True) +class P(Protocol[T]): + def meth(self) -> T: + pass + +[case testIncrementalSwitchFromNominalToStructural] +import a +[file a.py] +from b import B, fun +class C(B): + def x(self) -> int: pass + def y(self) -> int: pass +fun(C()) + +[file b.py] +from typing import Protocol +class B: + def x(self) -> float: pass +def fun(arg: B) -> None: + arg.x() + +[file b.py.2] +from typing import Protocol +class B(Protocol): + def x(self) -> float: pass +def fun(arg: B) -> None: + arg.x() + +[file a.py.3] +from b import fun +class C: + def x(self) -> int: pass + def y(self) -> int: pass +fun(C()) +[out1] +[out2] +[out3] + +[case testIncrementalSwitchFromStructuralToNominal] +import a +[file a.py] +from b import fun +class C: + def x(self) -> int: pass + def y(self) -> int: pass +fun(C()) + +[file b.py] +from typing import Protocol +class B(Protocol): + def x(self) -> float: pass +def fun(arg: B) -> None: + arg.x() + +[file b.py.2] +from typing import Protocol +class B: + def x(self) -> float: pass +def fun(arg: B) -> None: + arg.x() + +[out1] +[out2] +tmp/a.py:5: error: Argument 1 to "fun" has incompatible type "C"; expected "B" + [case testIncrementalWorksWithNamedTuple] import foo diff --git a/test-data/unit/check-lists.test b/test-data/unit/check-lists.test index c9c67e80d4fb..c575f4b76da4 100644 --- a/test-data/unit/check-lists.test +++ b/test-data/unit/check-lists.test @@ -64,7 +64,7 @@ class C: pass [case testListWithStarExpr] (x, *a) = [1, 2, 3] a = [1, *[2, 3]] -reveal_type(a) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(a) # E: Revealed type is 'builtins.list[builtins.int*]' b = [0, *a] reveal_type(b) # E: Revealed type is 'builtins.list[builtins.int*]' c = [*a, 0] diff --git a/test-data/unit/check-newtype.test b/test-data/unit/check-newtype.test index 4c7e1468a167..c9e01edf3ef4 100644 --- a/test-data/unit/check-newtype.test +++ b/test-data/unit/check-newtype.test @@ -332,6 +332,22 @@ B = NewType('B', A) class C(B): pass # E: Cannot subclass NewType [out] +[case testCannotUseNewTypeWithProtocols] +from typing import Protocol, NewType + +class P(Protocol): + attr: int +class D: + attr: int + +C = NewType('C', P) # E: NewType cannot be used with protocol classes + +x: C = C(D()) # We still accept this, treating 'C' as non-protocol subclass. +reveal_type(x.attr) # E: Revealed type is 'builtins.int' +x.bad_attr # E: "C" has no attribute "bad_attr" +C(1) # E: Argument 1 to "C" has incompatible type "int"; expected "P" +[out] + [case testNewTypeAny] from typing import NewType Any = NewType('Any', int) diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test new file mode 100644 index 000000000000..ba190988be61 --- /dev/null +++ b/test-data/unit/check-protocols.test @@ -0,0 +1,2120 @@ +-- Simple protocol types +-- --------------------- + +[case testCannotInstantiateProtocol] +from typing import Protocol + +class P(Protocol): + def meth(self) -> None: + pass + +P() # E: Cannot instantiate protocol class "P" + +[case testSimpleProtocolOneMethod] +from typing import Protocol + +class P(Protocol): + def meth(self) -> None: + pass + +class B: pass +class C: + def meth(self) -> None: + pass + +x: P +def fun(x: P) -> None: + x.meth() + x.meth(x) # E: Too many arguments for "meth" of "P" + x.bad # E: "P" has no attribute "bad" + +x = C() +x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P") + +fun(C()) +fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "P" + +def fun2() -> P: + return C() +def fun3() -> P: + return B() # E: Incompatible return value type (got "B", expected "P") + +[case testSimpleProtocolOneAbstractMethod] +from typing import Protocol +from abc import abstractmethod + +class P(Protocol): + @abstractmethod + def meth(self) -> None: + pass + +class B: pass +class C: + def meth(self) -> None: + pass +class D(B): + def meth(self) -> None: + pass + +x: P +def fun(x: P) -> None: + x.meth() + x.meth(x) # E: Too many arguments for "meth" of "P" + x.bad # E: "P" has no attribute "bad" + +x = C() +x = D() +x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P") +fun(C()) +fun(D()) +fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "P" +fun(x) + +[case testProtocolMethodBodies] +from typing import Protocol, List + +class P(Protocol): + def meth(self) -> int: + return 'no way' # E: Incompatible return value type (got "str", expected "int") + +# explicit ellipsis is OK in protocol methods +class P2(Protocol): + def meth2(self) -> List[int]: + ... +[builtins fixtures/list.pyi] + +[case testSimpleProtocolOneMethodOverride] +from typing import Protocol, Union + +class P(Protocol): + def meth(self) -> Union[int, str]: + pass +class SubP(P, Protocol): + def meth(self) -> int: + pass + +class B: pass +class C: + def meth(self) -> int: + pass +z: P +x: SubP +def fun(x: SubP) -> str: + x.bad # E: "SubP" has no attribute "bad" + return x.meth() # E: Incompatible return value type (got "int", expected "str") + +z = x +x = C() +x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "SubP") + +reveal_type(fun(C())) # E: Revealed type is 'builtins.str' +fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "SubP" + +[case testSimpleProtocolTwoMethodsMerge] +from typing import Protocol + +class P1(Protocol): + def meth1(self) -> int: + pass +class P2(Protocol): + def meth2(self) -> str: + pass +class P(P1, P2, Protocol): pass + +class B: pass +class C1: + def meth1(self) -> int: + pass +class C2(C1): + def meth2(self) -> str: + pass +class C: + def meth1(self) -> int: + pass + def meth2(self) -> str: + pass + +class AnotherP(Protocol): + def meth1(self) -> int: + pass + def meth2(self) -> str: + pass + +x: P +reveal_type(x.meth1()) # E: Revealed type is 'builtins.int' +reveal_type(x.meth2()) # E: Revealed type is 'builtins.str' + +c: C +c1: C1 +c2: C2 +y: AnotherP + +x = c +x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P") +x = c1 # E: Incompatible types in assignment (expression has type "C1", variable has type "P") \ + # N: 'C1' is missing following 'P' protocol member: \ + # N: meth2 +x = c2 +x = y +y = x + +[case testSimpleProtocolTwoMethodsExtend] +from typing import Protocol + +class P1(Protocol): + def meth1(self) -> int: + pass +class P2(P1, Protocol): + def meth2(self) -> str: + pass + +class Cbad: + def meth1(self) -> int: + pass + +class C: + def meth1(self) -> int: + pass + def meth2(self) -> str: + pass + +x: P2 +reveal_type(x.meth1()) # E: Revealed type is 'builtins.int' +reveal_type(x.meth2()) # E: Revealed type is 'builtins.str' + +x = C() # OK +x = Cbad() # E: Incompatible types in assignment (expression has type "Cbad", variable has type "P2") \ + # N: 'Cbad' is missing following 'P2' protocol member: \ + # N: meth2 + +[case testProtocolMethodVsAttributeErrors] +from typing import Protocol + +class P(Protocol): + def meth(self) -> int: + pass +class C: + meth: int +x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \ + # N: Following member(s) of "C" have conflicts: \ + # N: meth: expected Callable[[], int], got "int" + +[case testProtocolMethodVsAttributeErrors2] +from typing import Protocol + +class P(Protocol): + @property + def meth(self) -> int: + pass +class C: + def meth(self) -> int: + pass +x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \ + # N: Following member(s) of "C" have conflicts: \ + # N: meth: expected "int", got Callable[[], int] +[builtins fixtures/property.pyi] + +[case testCannotAssignNormalToProtocol] +from typing import Protocol + +class P(Protocol): + def meth(self) -> int: + pass +class C: + def meth(self) -> int: + pass + +x: C +y: P +x = y # E: Incompatible types in assignment (expression has type "P", variable has type "C") + +[case testIndependentProtocolSubtyping] +from typing import Protocol + +class P1(Protocol): + def meth(self) -> int: + pass +class P2(Protocol): + def meth(self) -> int: + pass + +x1: P1 +x2: P2 + +x1 = x2 +x2 = x1 + +def f1(x: P1) -> None: pass +def f2(x: P2) -> None: pass + +f1(x2) +f2(x1) + +[case testNoneDisablesProtocolImplementation] +from typing import Protocol + +class MyHashable(Protocol): + def __my_hash__(self) -> int: + return 0 + +class C: + __my_hash__ = None + +var: MyHashable = C() # E: Incompatible types in assignment (expression has type "C", variable has type "MyHashable") + +[case testNoneDisablesProtocolSubclassingWithStrictOptional] +# flags: --strict-optional +from typing import Protocol + +class MyHashable(Protocol): + def __my_hash__(self) -> int: + return 0 + +class C(MyHashable): + __my_hash__ = None # E: Incompatible types in assignment \ +(expression has type None, base class "MyHashable" defined the type as Callable[[MyHashable], int]) + +[case testProtocolsWithNoneAndStrictOptional] +# flags: --strict-optional +from typing import Protocol +class P(Protocol): + x = 0 # type: int + +class C: + x = None + +x: P = C() # Error! +def f(x: P) -> None: pass +f(C()) # Error! +[out] +main:9: error: Incompatible types in assignment (expression has type "C", variable has type "P") +main:9: note: Following member(s) of "C" have conflicts: +main:9: note: x: expected "int", got None +main:11: error: Argument 1 to "f" has incompatible type "C"; expected "P" +main:11: note: Following member(s) of "C" have conflicts: +main:11: note: x: expected "int", got None + +-- Semanal errors in protocol types +-- -------------------------------- + +[case testBasicSemanalErrorsInProtocols] +from typing import Protocol, Generic, TypeVar, Iterable + +T = TypeVar('T', covariant=True) +S = TypeVar('S', covariant=True) + +class P1(Protocol[T, T]): # E: Duplicate type variables in Generic[...] or Protocol[...] + def meth(self) -> T: + pass + +class P2(Protocol[T], Protocol[S]): # E: Only single Generic[...] or Protocol[...] can be in bases + def meth(self) -> T: + pass + +class P3(Protocol[T], Generic[S]): # E: Only single Generic[...] or Protocol[...] can be in bases + def meth(self) -> T: + pass + +class P4(Protocol[T]): + attr: Iterable[S] # E: Invalid type "__main__.S" + +class P5(Iterable[S], Protocol[T]): # E: If Generic[...] or Protocol[...] is present it should list all type variables + def meth(self) -> T: + pass + +[case testProhibitSelfDefinitionInProtocols] +from typing import Protocol + +class P(Protocol): + def __init__(self, a: int) -> None: + self.a = a # E: Protocol members cannot be defined via assignment to self \ + # E: "P" has no attribute "a" + +class B: pass +class C: + def __init__(self, a: int) -> None: + pass + +x: P +x = B() +# The above has an incompatible __init__, but mypy ignores this for nominal subtypes? +x = C(1) + +class P2(Protocol): + a: int + def __init__(self) -> None: + self.a = 1 + +class B2(P2): + a: int + +x2: P2 = B2() # OK + +[case testProtocolAndRuntimeAreDefinedAlsoInTypingExtensions] +from typing_extensions import Protocol, runtime + +@runtime +class P(Protocol): + def meth(self) -> int: + pass + +x: object +if isinstance(x, P): + reveal_type(x) # E: Revealed type is '__main__.P' + reveal_type(x.meth()) # E: Revealed type is 'builtins.int' + +class C: + def meth(self) -> int: + pass + +z: P = C() +[builtins fixtures/dict.pyi] + +[case testProtocolsCannotInheritFromNormal] +from typing import Protocol + +class C: pass +class D: pass + +class P(C, Protocol): # E: All bases of a protocol must be protocols + attr: int + +class P2(P, D, Protocol): # E: All bases of a protocol must be protocols + pass + +P2() # E: Cannot instantiate abstract class 'P2' with abstract attribute 'attr' +p: P2 +reveal_type(p.attr) # E: Revealed type is 'builtins.int' + +-- Generic protocol types +-- ---------------------- + +[case testGenericMethodWithProtocol] +from typing import Protocol, TypeVar +T = TypeVar('T') + +class P(Protocol): + def meth(self, x: int) -> int: + return x +class C: + def meth(self, x: T) -> T: + return x + +x: P = C() + +[case testGenericMethodWithProtocol2] +from typing import Protocol, TypeVar +T = TypeVar('T') + +class P(Protocol): + def meth(self, x: T) -> T: + return x +class C: + def meth(self, x: int) -> int: + return x + +x: P = C() +[out] +main:11: error: Incompatible types in assignment (expression has type "C", variable has type "P") +main:11: note: Following member(s) of "C" have conflicts: +main:11: note: Expected: +main:11: note: def [T] meth(self, x: T) -> T +main:11: note: Got: +main:11: note: def meth(self, x: int) -> int + +[case testAutomaticProtocolVariance] +from typing import TypeVar, Protocol + +T = TypeVar('T') + +# In case of these errors we proceed with declared variance. +class Pco(Protocol[T]): # E: Invariant type variable 'T' used in protocol where covariant one is expected + def meth(self) -> T: + pass +class Pcontra(Protocol[T]): # E: Invariant type variable 'T' used in protocol where contravariant one is expected + def meth(self, x: T) -> None: + pass +class Pinv(Protocol[T]): + attr: T + +class A: pass +class B(A): pass + +x1: Pco[B] +y1: Pco[A] +x1 = y1 # E: Incompatible types in assignment (expression has type Pco[A], variable has type Pco[B]) +y1 = x1 # E: Incompatible types in assignment (expression has type Pco[B], variable has type Pco[A]) + +x2: Pcontra[B] +y2: Pcontra[A] +y2 = x2 # E: Incompatible types in assignment (expression has type Pcontra[B], variable has type Pcontra[A]) +x2 = y2 # E: Incompatible types in assignment (expression has type Pcontra[A], variable has type Pcontra[B]) + +x3: Pinv[B] +y3: Pinv[A] +y3 = x3 # E: Incompatible types in assignment (expression has type Pinv[B], variable has type Pinv[A]) +x3 = y3 # E: Incompatible types in assignment (expression has type Pinv[A], variable has type Pinv[B]) + +[case testProtocolVarianceWithCallableAndList] +from typing import Protocol, TypeVar, Callable, List +T = TypeVar('T') +S = TypeVar('S') +T_co = TypeVar('T_co', covariant=True) + +class P(Protocol[T, S]): # E: Invariant type variable 'T' used in protocol where covariant one is expected \ + # E: Invariant type variable 'S' used in protocol where contravariant one is expected + def fun(self, callback: Callable[[T], S]) -> None: pass + +class P2(Protocol[T_co]): # E: Covariant type variable 'T_co' used in protocol where invariant one is expected + lst: List[T_co] +[builtins fixtures/list.pyi] + +[case testProtocolVarianceWithUnusedVariable] +from typing import Protocol, TypeVar +T = TypeVar('T') + +class P(Protocol[T]): # E: Invariant type variable 'T' used in protocol where covariant one is expected + attr: int + +[case testGenericProtocolsInference1] +from typing import Protocol, Sequence, TypeVar + +T = TypeVar('T', covariant=True) + +class Closeable(Protocol[T]): + def close(self) -> T: + pass + +class F: + def close(self) -> int: + return 0 + +def close(arg: Closeable[T]) -> T: + return arg.close() + +def close_all(args: Sequence[Closeable[T]]) -> T: + for arg in args: + arg.close() + return args[0].close() + +arg: Closeable[int] + +reveal_type(close(F())) # E: Revealed type is 'builtins.int*' +reveal_type(close(arg)) # E: Revealed type is 'builtins.int*' +reveal_type(close_all([F()])) # E: Revealed type is 'builtins.int*' +reveal_type(close_all([arg])) # E: Revealed type is 'builtins.int*' +[builtins fixtures/isinstancelist.pyi] + +[case testProtocolGenericInference2] +from typing import Generic, TypeVar, Protocol +T = TypeVar('T') +S = TypeVar('S') + +class P(Protocol[T, S]): + x: T + y: S + +class C: + x: int + y: int + +def fun3(x: P[T, T]) -> T: + pass +reveal_type(fun3(C())) # E: Revealed type is 'builtins.int*' + +[case testProtocolGenericInferenceCovariant] +from typing import Generic, TypeVar, Protocol +T = TypeVar('T', covariant=True) +S = TypeVar('S', covariant=True) +U = TypeVar('U') + +class P(Protocol[T, S]): + def x(self) -> T: pass + def y(self) -> S: pass + +class C: + def x(self) -> int: pass + def y(self) -> int: pass + +def fun4(x: U, y: P[U, U]) -> U: + pass +reveal_type(fun4('a', C())) # E: Revealed type is 'builtins.object*' + +[case testUnrealtedGenericProtolsEquivalent] +from typing import TypeVar, Protocol +T = TypeVar('T') + +class PA(Protocol[T]): + attr: int + def meth(self) -> T: pass + def other(self, arg: T) -> None: pass +class PB(Protocol[T]): # exactly the same as above + attr: int + def meth(self) -> T: pass + def other(self, arg: T) -> None: pass + +def fun(x: PA[T]) -> PA[T]: + y: PB[T] = x + z: PB[T] + return z + +x: PA +y: PB +x = y +y = x + +xi: PA[int] +yi: PB[int] +xi = yi +yi = xi + +[case testGenericSubProtocols] +from typing import TypeVar, Protocol, Tuple, Generic + +T = TypeVar('T') +S = TypeVar('S') + +class P1(Protocol[T]): + attr1: T +class P2(P1[T], Protocol[T, S]): + attr2: Tuple[T, S] + +class C: + def __init__(self, a1: int, a2: Tuple[int, int]) -> None: + self.attr1 = a1 + self.attr2 = a2 + +c: C +var: P2[int, int] = c +var2: P2[int, str] = c # E: Incompatible types in assignment (expression has type "C", variable has type P2[int, str]) \ + # N: Following member(s) of "C" have conflicts: \ + # N: attr2: expected "Tuple[int, str]", got "Tuple[int, int]" + +class D(Generic[T]): + attr1: T +class E(D[T]): + attr2: Tuple[T, T] + +def f(x: T) -> T: + z: P2[T, T] = E[T]() + y: P2[T, T] = D[T]() # E: Incompatible types in assignment (expression has type D[T], variable has type P2[T, T]) \ + # N: 'D' is missing following 'P2' protocol member: \ + # N: attr2 + return x +[builtins fixtures/isinstancelist.pyi] + +[case testGenericSubProtocolsExtensionInvariant] +from typing import TypeVar, Protocol, Union + +T = TypeVar('T') +S = TypeVar('S') + +class P1(Protocol[T]): + attr1: T +class P2(Protocol[T]): + attr2: T +class P(P1[T], P2[S], Protocol): + pass + +class C: + attr1: int + attr2: str + +class A: + attr1: A +class B: + attr2: B +class D(A, B): pass + +x: P = D() # Same as P[Any, Any] + +var: P[Union[int, P], Union[P, str]] = C() # E: Incompatible types in assignment (expression has type "C", variable has type P[Union[int, P[Any, Any]], Union[P[Any, Any], str]]) \ + # N: Following member(s) of "C" have conflicts: \ + # N: attr1: expected "Union[int, P[Any, Any]]", got "int" \ + # N: attr2: expected "Union[P[Any, Any], str]", got "str" + +[case testGenericSubProtocolsExtensionCovariant] +from typing import TypeVar, Protocol, Union + +T = TypeVar('T', covariant=True) +S = TypeVar('S', covariant=True) + +class P1(Protocol[T]): + def attr1(self) -> T: pass +class P2(Protocol[T]): + def attr2(self) -> T: pass +class P(P1[T], P2[S], Protocol): + pass + +class C: + def attr1(self) -> int: pass + def attr2(self) -> str: pass + +var: P[Union[int, P], Union[P, str]] = C() # OK for covariant +var2: P[Union[str, P], Union[P, int]] = C() +[out] +main:18: error: Incompatible types in assignment (expression has type "C", variable has type P[Union[str, P[Any, Any]], Union[P[Any, Any], int]]) +main:18: note: Following member(s) of "C" have conflicts: +main:18: note: Expected: +main:18: note: def attr1(self) -> Union[str, P[Any, Any]] +main:18: note: Got: +main:18: note: def attr1(self) -> int +main:18: note: Expected: +main:18: note: def attr2(self) -> Union[P[Any, Any], int] +main:18: note: Got: +main:18: note: def attr2(self) -> str + +[case testSelfTypesWithProtocolsBehaveAsWithNominal] +from typing import Protocol, TypeVar + +T = TypeVar('T', bound=Shape) +class Shape(Protocol): + def combine(self: T, other: T) -> T: + pass + +class NonProtoShape: + def combine(self: T, other: T) -> T: + pass +class Circle: + def combine(self: T, other: Shape) -> T: + pass +class Triangle: + def combine(self, other: Shape) -> Shape: + pass +class Bad: + def combine(self, other: int) -> str: + pass + +def f(s: Shape) -> None: pass +s: Shape + +f(NonProtoShape()) +f(Circle()) +s = Triangle() +s = Bad() + +n2: NonProtoShape = s +[out] +main:26: error: Incompatible types in assignment (expression has type "Triangle", variable has type "Shape") +main:26: note: Following member(s) of "Triangle" have conflicts: +main:26: note: Expected: +main:26: note: def combine(self, other: Triangle) -> Triangle +main:26: note: Got: +main:26: note: def combine(self, other: Shape) -> Shape +main:27: error: Incompatible types in assignment (expression has type "Bad", variable has type "Shape") +main:27: note: Following member(s) of "Bad" have conflicts: +main:27: note: Expected: +main:27: note: def combine(self, other: Bad) -> Bad +main:27: note: Got: +main:27: note: def combine(self, other: int) -> str +main:29: error: Incompatible types in assignment (expression has type "Shape", variable has type "NonProtoShape") + +[case testBadVarianceInProtocols] +from typing import Protocol, TypeVar + +T_co = TypeVar('T_co', covariant=True) +T_contra = TypeVar('T_contra', contravariant=True) + +class Proto(Protocol[T_co, T_contra]): # type: ignore + def one(self, x: T_co) -> None: # E: Cannot use a covariant type variable as a parameter + pass + def other(self) -> T_contra: # E: Cannot use a contravariant type variable as return type + pass + +# Check that we respect user overrides of variance after the errors are reported +x: Proto[int, float] +y: Proto[float, int] +y = x # OK +[builtins fixtures/list.pyi] + +[case testSubtleBadVarianceInProtocols] +from typing import Protocol, TypeVar, Iterable, Sequence + +T_co = TypeVar('T_co', covariant=True) +T_contra = TypeVar('T_contra', contravariant=True) + +class Proto(Protocol[T_co, T_contra]): # E: Covariant type variable 'T_co' used in protocol where contravariant one is expected \ + # E: Contravariant type variable 'T_contra' used in protocol where covariant one is expected + def one(self, x: Iterable[T_co]) -> None: + pass + def other(self) -> Sequence[T_contra]: + pass + +# Check that we respect user overrides of variance after the errors are reported +x: Proto[int, float] +y: Proto[float, int] +y = x # OK +[builtins fixtures/list.pyi] + +-- Recursive protocol types +-- ------------------------ + +[case testRecursiveProtocols1] +from typing import Protocol, Sequence, List, Generic, TypeVar + +T = TypeVar('T') + +class Traversable(Protocol): + @property + def leaves(self) -> Sequence[Traversable]: pass + +class C: pass + +class D(Generic[T]): + leaves: List[D[T]] + +t: Traversable +t = D[int]() # OK +t = C() # E: Incompatible types in assignment (expression has type "C", variable has type "Traversable") +[builtins fixtures/list.pyi] + +[case testRecursiveProtocols2] +from typing import Protocol, TypeVar + +T = TypeVar('T') +class Linked(Protocol[T]): + val: T + def next(self) -> Linked[T]: pass + +class L: + val: int + def next(self) -> L: pass + +def last(seq: Linked[T]) -> T: + pass + +reveal_type(last(L())) # E: Revealed type is 'builtins.int*' +[builtins fixtures/list.pyi] + +[case testRecursiveProtocolSubtleMismatch] +from typing import Protocol, TypeVar + +T = TypeVar('T') +class Linked(Protocol[T]): + val: T + def next(self) -> Linked[T]: pass +class L: + val: int + def next(self) -> int: pass + +def last(seq: Linked[T]) -> T: + pass +last(L()) # E: Argument 1 to "last" has incompatible type "L"; expected Linked[] + +[case testMutuallyRecursiveProtocols] +from typing import Protocol, Sequence, List + +class P1(Protocol): + @property + def attr1(self) -> Sequence[P2]: pass +class P2(Protocol): + @property + def attr2(self) -> Sequence[P1]: pass + +class C: pass +class A: + attr1: List[B] +class B: + attr2: List[A] + +t: P1 +t = A() # OK +t = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P1") +t = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P1") +[builtins fixtures/list.pyi] + +[case testMutuallyRecursiveProtocolsTypesWithSubteMismatch] +from typing import Protocol, Sequence, List + +class P1(Protocol): + @property + def attr1(self) -> Sequence[P2]: pass +class P2(Protocol): + @property + def attr2(self) -> Sequence[P1]: pass + +class C: pass +class A: + attr1: List[B] +class B: + attr2: List[C] + +t: P1 +t = A() # E: Incompatible types in assignment (expression has type "A", variable has type "P1") \ + # N: Following member(s) of "A" have conflicts: \ + # N: attr1: expected Sequence[P2], got List[B] +[builtins fixtures/list.pyi] + +[case testMutuallyRecursiveProtocolsTypesWithSubteMismatchWriteable] +from typing import Protocol + +class P1(Protocol): + @property + def attr1(self) -> P2: pass +class P2(Protocol): + attr2: P1 + +class A: + attr1: B +class B: + attr2: A + +x: P1 = A() # E: Incompatible types in assignment (expression has type "A", variable has type "P1") \ + # N: Following member(s) of "A" have conflicts: \ + # N: attr1: expected "P2", got "B" +[builtins fixtures/property.pyi] + +-- FIXME: things like this should work +[case testWeirdRecursiveInferenceForProtocols-skip] +from typing import Protocol, TypeVar, Generic +T_co = TypeVar('T_co', covariant=True) +T = TypeVar('T') + +class P(Protocol[T_co]): + def meth(self) -> P[T_co]: pass + +class C(Generic[T]): + def meth(self) -> C[T]: pass + +x: C[int] +def f(arg: P[T]) -> T: pass +reveal_type(f(x)) #E: Revealed type is 'builtins.int*' + +-- @property, @classmethod and @staticmethod in protocol types +-- ----------------------------------------------------------- + +[case testCannotInstantiateAbstractMethodExplicitProtocolSubtypes] +from typing import Protocol +from abc import abstractmethod + +class P(Protocol): + @abstractmethod + def meth(self) -> int: + pass + +class A(P): + pass + +A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'meth' + +class C(A): + def meth(self) -> int: + pass +class C2(P): + def meth(self) -> int: + pass + +C() +C2() + +[case testCannotInstantiateAbstractVariableExplicitProtocolSubtypes] +from typing import Protocol + +class P(Protocol): + attr: int + +class A(P): + pass + +A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'attr' + +class C(A): + attr: int +class C2(P): + def __init__(self) -> None: + self.attr = 1 + +C() +C2() + +class P2(Protocol): + attr: int = 1 + +class B(P2): pass +B() # OK, attr is not abstract + +[case testClassVarsInProtocols] +from typing import Protocol, ClassVar + +class PInst(Protocol): + v: int + +class PClass(Protocol): + v: ClassVar[int] + +class CInst: + v: int + +class CClass: + v: ClassVar[int] + +x: PInst +y: PClass + +x = CInst() +x = CClass() # E: Incompatible types in assignment (expression has type "CClass", variable has type "PInst") \ + # N: Protocol member PInst.v expected instance variable, got class variable +y = CClass() +y = CInst() # E: Incompatible types in assignment (expression has type "CInst", variable has type "PClass") \ + # N: Protocol member PClass.v expected class variable, got instance variable + +[case testPropertyInProtocols] +from typing import Protocol + +class PP(Protocol): + @property + def attr(self) -> int: + pass + +class P(Protocol): + attr: int + +x: P +y: PP +y = x + +x2: P +y2: PP +x2 = y2 # E: Incompatible types in assignment (expression has type "PP", variable has type "P") \ + # N: Protocol member P.attr expected settable variable, got read-only attribute +[builtins fixtures/property.pyi] + +[case testSettablePropertyInProtocols] +from typing import Protocol + +class PPS(Protocol): + @property + def attr(self) -> int: + pass + @attr.setter + def attr(self, x: int) -> None: + pass + +class PP(Protocol): + @property + def attr(self) -> int: + pass + +class P(Protocol): + attr: int + +x: P +z: PPS +z = x + +x2: P +z2: PPS +x2 = z2 + +y3: PP +z3: PPS +y3 = z3 + +y4: PP +z4: PPS +z4 = y4 # E: Incompatible types in assignment (expression has type "PP", variable has type "PPS") \ + # N: Protocol member PPS.attr expected settable variable, got read-only attribute +[builtins fixtures/property.pyi] + +[case testStaticAndClassMethodsInProtocols] +from typing import Protocol, Type, TypeVar + +class P(Protocol): + def meth(self, x: int) -> str: + pass + +class PC(Protocol): + @classmethod + def meth(cls, x: int) -> str: + pass + +class B: + @staticmethod + def meth(x: int) -> str: + pass + +class C: + def meth(self, x: int) -> str: + pass + +x: P +x = C() +x = B() + +y: PC +y = B() +y = C() # E: Incompatible types in assignment (expression has type "C", variable has type "PC") \ + # N: Protocol member PC.meth expected class or static method +[builtins fixtures/classmethod.pyi] + +[case testOverloadedMethodsInProtocols] +from typing import overload, Protocol, Union + +class P(Protocol): + @overload + def f(self, x: int) -> int: pass + @overload + def f(self, x: str) -> str: pass + +class C: + def f(self, x: Union[int, str]) -> None: + pass +class D: + def f(self, x: int) -> None: + pass + +x: P = C() +x = D() +[out] +main:17: error: Incompatible types in assignment (expression has type "D", variable has type "P") +main:17: note: Following member(s) of "D" have conflicts: +main:17: note: Expected: +main:17: note: @overload +main:17: note: def f(self, x: int) -> int +main:17: note: @overload +main:17: note: def f(self, x: str) -> str +main:17: note: Got: +main:17: note: def f(self, x: int) -> None + +[case testCannotInstantiateProtocolWithOverloadedUnimplementedMethod] +from typing import overload, Protocol + +class P(Protocol): + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: str) -> bytes: pass +class C(P): + pass +C() # E: Cannot instantiate abstract class 'C' with abstract attribute 'meth' + +[case testCanUseOverloadedImplementationsInProtocols] +from typing import overload, Protocol, Union +class P(Protocol): + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: str) -> bool: pass + def meth(self, x: Union[int, str]): + if isinstance(x, int): + return x + return True + +class C(P): + pass +x = C() +reveal_type(x.meth('hi')) # E: Revealed type is 'builtins.bool' +[builtins fixtures/isinstance.pyi] + +[case testProtocolsWithIdenticalOverloads] +from typing import overload, Protocol + +class PA(Protocol): + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: str) -> bytes: pass +class PB(Protocol): # identical to above + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: str) -> bytes: pass + +x: PA +y: PB +x = y +def fun(arg: PB) -> None: pass +fun(x) + +[case testProtocolsWithIncompatibleOverloads] +from typing import overload, Protocol + +class PA(Protocol): + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: str) -> bytes: pass +class PB(Protocol): + @overload + def meth(self, x: int) -> int: pass + @overload + def meth(self, x: bytes) -> str: pass + +x: PA +y: PB +x = y +[out] +main:16: error: Incompatible types in assignment (expression has type "PB", variable has type "PA") +main:16: note: Following member(s) of "PB" have conflicts: +main:16: note: Expected: +main:16: note: @overload +main:16: note: def meth(self, x: int) -> int +main:16: note: @overload +main:16: note: def meth(self, x: str) -> bytes +main:16: note: Got: +main:16: note: @overload +main:16: note: def meth(self, x: int) -> int +main:16: note: @overload +main:16: note: def meth(self, x: bytes) -> str + +-- Join and meet with protocol types +-- --------------------------------- + +[case testJoinProtocolWithProtocol] +from typing import Protocol + +class P(Protocol): + attr: int +class P2(Protocol): + attr: int + attr2: str + +x: P +y: P2 + +l0 = [x, x] +l1 = [y, y] +l = [x, y] +reveal_type(l0) # E: Revealed type is 'builtins.list[__main__.P*]' +reveal_type(l1) # E: Revealed type is 'builtins.list[__main__.P2*]' +reveal_type(l) # E: Revealed type is 'builtins.list[__main__.P*]' +[builtins fixtures/list.pyi] + +[case testJoinOfIncompatibleProtocols] +from typing import Protocol + +class P(Protocol): + attr: int +class P2(Protocol): + attr2: str + +x: P +y: P2 +reveal_type([x, y]) # E: Revealed type is 'builtins.list[builtins.object*]' +[builtins fixtures/list.pyi] + +[case testJoinProtocolWithNormal] +from typing import Protocol + +class P(Protocol): + attr: int + +class C: + attr: int + +x: P +y: C + +l = [x, y] + +reveal_type(l) # E: Revealed type is 'builtins.list[__main__.P*]' +[builtins fixtures/list.pyi] + +[case testMeetProtocolWithProtocol] +from typing import Protocol, Callable, TypeVar + +class P(Protocol): + attr: int +class P2(Protocol): + attr: int + attr2: str + +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: P, y: P2) -> None: pass +reveal_type(f(g)) # E: Revealed type is '__main__.P2*' + +[case testMeetOfIncompatibleProtocols] +from typing import Protocol, Callable, TypeVar + +class P(Protocol): + attr: int +class P2(Protocol): + attr2: str + +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: P, y: P2) -> None: pass +x = f(g) # E: "f" does not return a value + +[case testMeetProtocolWithNormal] +from typing import Protocol, Callable, TypeVar + +class P(Protocol): + attr: int +class C: + attr: int + +T = TypeVar('T') +def f(x: Callable[[T, T], None]) -> T: pass +def g(x: P, y: C) -> None: pass +reveal_type(f(g)) # E: Revealed type is '__main__.C*' + +[case testInferProtocolFromProtocol] +from typing import Protocol, Sequence, TypeVar, Generic + +T = TypeVar('T') +class Box(Protocol[T]): + content: T +class Linked(Protocol[T]): + val: T + def next(self) -> Linked[T]: pass + +class L(Generic[T]): + val: Box[T] + def next(self) -> L[T]: pass + +def last(seq: Linked[T]) -> T: + pass + +reveal_type(last(L[int]())) # E: Revealed type is '__main__.Box*[builtins.int*]' +reveal_type(last(L[str]()).content) # E: Revealed type is 'builtins.str*' + +[case testOverloadOnProtocol] +from typing import overload, Protocol, runtime + +@runtime +class P1(Protocol): + attr1: int +class P2(Protocol): + attr2: str + +class C1: + attr1: int +class C2: + attr2: str +class C: pass + +@overload +def f(x: P1) -> int: ... +@overload +def f(x: P2) -> str: ... +def f(x): + if isinstance(x, P1): + return P1.attr1 + if isinstance(x, P2): # E: Only @runtime protocols can be used with instance and class checks + return P1.attr2 + +reveal_type(f(C1())) # E: Revealed type is 'builtins.int' +reveal_type(f(C2())) # E: Revealed type is 'builtins.str' +class D(C1, C2): pass # Compatible with both P1 and P2 +# FIXME: the below is not right, see #1322 +reveal_type(f(D())) # E: Revealed type is 'Any' +f(C()) # E: No overload variant of "f" matches argument types [__main__.C] +[builtins fixtures/isinstance.pyi] + +-- Unions of protocol types +-- ------------------------ + +[case testBasicUnionsOfProtocols] +from typing import Union, Protocol + +class P1(Protocol): + attr1: int +class P2(Protocol): + attr2: int + +class C1: + attr1: int +class C2: + attr2: int +class C(C1, C2): + pass + +class B: ... + +x: Union[P1, P2] + +x = C1() +x = C2() +x = C() +x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "Union[P1, P2]") + +[case testUnionsOfNormalClassesWithProtocols] +from typing import Protocol, Union + +class P1(Protocol): + attr1: int +class P2(Protocol): + attr2: int + +class C1: + attr1: int +class C2: + attr2: int +class C(C1, C2): + pass + +class D1: + attr1: int + +def f1(x: P1) -> None: + pass +def f2(x: P2) -> None: + pass + +x: Union[C1, C2] +y: Union[C1, D1] +z: Union[C, D1] + +f1(x) # E: Argument 1 to "f1" has incompatible type "Union[C1, C2]"; expected "P1" +f1(y) +f1(z) +f2(x) # E: Argument 1 to "f2" has incompatible type "Union[C1, C2]"; expected "P2" +f2(z) # E: Argument 1 to "f2" has incompatible type "Union[C, D1]"; expected "P2" + +-- Type[] with protocol types +-- -------------------------- + +[case testInstantiationProtocolInTypeForFunctions] +from typing import Type, Protocol + +class P(Protocol): + def m(self) -> None: pass +class P1(Protocol): + def m(self) -> None: pass +class Pbad(Protocol): + def mbad(self) -> int: pass +class B(P): pass +class C: + def m(self) -> None: + pass + +def f(cls: Type[P]) -> P: + return cls() # OK +def g() -> P: + return P() # E: Cannot instantiate protocol class "P" + +f(P) # E: Only concrete class can be given where 'Type[P]' is expected +f(B) # OK +f(C) # OK +x: Type[P1] +xbad: Type[Pbad] +f(x) # OK +f(xbad) # E: Argument 1 to "f" has incompatible type Type[Pbad]; expected Type[P] + +[case testInstantiationProtocolInTypeForAliases] +from typing import Type, Protocol + +class P(Protocol): + def m(self) -> None: pass +class C: + def m(self) -> None: + pass + +def f(cls: Type[P]) -> P: + return cls() # OK + +Alias = P +GoodAlias = C +Alias() # E: Cannot instantiate protocol class "P" +GoodAlias() +f(Alias) # E: Only concrete class can be given where 'Type[P]' is expected +f(GoodAlias) + +[case testInstantiationProtocolInTypeForVariables] +from typing import Type, Protocol + +class P(Protocol): + def m(self) -> None: pass +class B(P): pass +class C: + def m(self) -> None: + pass + +var: Type[P] +var() +var = P # E: Can only assign concrete classes to a variable of type 'Type[P]' +var = B # OK +var = C # OK + +var_old = None # type: Type[P] # Old syntax for variable annotations +var_old() +var_old = P # E: Can only assign concrete classes to a variable of type 'Type[P]' +var_old = B # OK +var_old = C # OK + +[case testInstantiationProtocolInTypeForClassMethods] +from typing import Type, Protocol + +class Logger: + @staticmethod + def log(a: Type[C]): + pass +class C(Protocol): + @classmethod + def action(cls) -> None: + cls() #OK for classmethods + Logger.log(cls) #OK for classmethods +[builtins fixtures/classmethod.pyi] + +-- isinstance() with @runtime protocols +-- ------------------------------------ + +[case testSimpleRuntimeProtocolCheck] +from typing import Protocol, runtime + +@runtime # E: @runtime can only be used with protocol classes +class C: + pass + +class P(Protocol): + def meth(self) -> None: + pass + +@runtime +class R(Protocol): + def meth(self) -> int: + pass + +x: object + +if isinstance(x, P): # E: Only @runtime protocols can be used with instance and class checks + reveal_type(x) # E: Revealed type is '__main__.P' + +if isinstance(x, R): + reveal_type(x) # E: Revealed type is '__main__.R' + reveal_type(x.meth()) # E: Revealed type is 'builtins.int' +[builtins fixtures/isinstance.pyi] + +[case testRuntimeIterableProtocolCheck] +from typing import Iterable, List, Union + +x: Union[int, List[str]] + +if isinstance(x, Iterable): + reveal_type(x) # E: Revealed type is 'builtins.list[builtins.str]' +[builtins fixtures/isinstancelist.pyi] + +[case testConcreteClassesInProtocolsIsInstance] +from typing import Protocol, runtime, TypeVar, Generic + +T = TypeVar('T') + +@runtime +class P1(Protocol): + def meth1(self) -> int: + pass +@runtime +class P2(Protocol): + def meth2(self) -> int: + pass +@runtime +class P(P1, P2, Protocol): + pass + +class C1(Generic[T]): + def meth1(self) -> T: + pass +class C2: + def meth2(self) -> int: + pass +class C(C1[int], C2): pass + +c = C() +if isinstance(c, P1): + reveal_type(c) # E: Revealed type is '__main__.C' +else: + reveal_type(c) # Unreachable +if isinstance(c, P): + reveal_type(c) # E: Revealed type is '__main__.C' +else: + reveal_type(c) # Unreachable + +c1i: C1[int] +if isinstance(c1i, P1): + reveal_type(c1i) # E: Revealed type is '__main__.C1[builtins.int]' +else: + reveal_type(c1i) # Unreachable +if isinstance(c1i, P): + reveal_type(c1i) # Unreachable +else: + reveal_type(c1i) # E: Revealed type is '__main__.C1[builtins.int]' + +c1s: C1[str] +if isinstance(c1s, P1): + reveal_type(c1s) # Unreachable +else: + reveal_type(c1s) # E: Revealed type is '__main__.C1[builtins.str]' + +c2: C2 +if isinstance(c2, P): + reveal_type(c2) # Unreachable +else: + reveal_type(c2) # E: Revealed type is '__main__.C2' + +[builtins fixtures/isinstancelist.pyi] + +[case testConcreteClassesUnionInProtocolsIsInstance] +from typing import Protocol, runtime, TypeVar, Generic, Union + +T = TypeVar('T') + +@runtime +class P1(Protocol): + def meth1(self) -> int: + pass +@runtime +class P2(Protocol): + def meth2(self) -> int: + pass + +class C1(Generic[T]): + def meth1(self) -> T: + pass +class C2: + def meth2(self) -> int: + pass + +x: Union[C1[int], C2] +if isinstance(x, P1): + reveal_type(x) # E: Revealed type is '__main__.C1[builtins.int]' +else: + reveal_type(x) # E: Revealed type is '__main__.C2' + +if isinstance(x, P2): + reveal_type(x) # E: Revealed type is '__main__.C2' +else: + reveal_type(x) # E: Revealed type is '__main__.C1[builtins.int]' +[builtins fixtures/isinstancelist.pyi] + +-- Non-Instances and protocol types (Callable vs __call__ etc.) +-- ------------------------------------------------------------ + +[case testBasicTupleStructuralSubtyping] +from typing import Tuple, TypeVar, Protocol + +T = TypeVar('T', covariant=True) + +class MyProto(Protocol[T]): + def __len__(self) -> T: + pass + +t: Tuple[int, str] +def f(x: MyProto[int]) -> None: + pass +f(t) # OK + +y: MyProto[str] +y = t # E: Incompatible types in assignment (expression has type "Tuple[int, str]", variable has type MyProto[str]) +[builtins fixtures/isinstancelist.pyi] + +[case testBasicNamedTupleStructuralSubtyping] +from typing import NamedTuple, TypeVar, Protocol + +T = TypeVar('T', covariant=True) +S = TypeVar('S', covariant=True) + +class P(Protocol[T, S]): + @property + def x(self) -> T: pass + @property + def y(self) -> S: pass + +class N(NamedTuple): + x: int + y: str +class N2(NamedTuple): + x: int +class N3(NamedTuple): + x: int + y: int + +z: N +z3: N3 + +def fun(x: P[int, str]) -> None: + pass +def fun2(x: P[int, int]) -> None: + pass +def fun3(x: P[T, T]) -> T: + return x.x + +fun(z) +fun2(z) # E: Argument 1 to "fun2" has incompatible type "N"; expected P[int, int] \ + # N: Following member(s) of "N" have conflicts: \ + # N: y: expected "int", got "str" + +fun(N2(1)) # E: Argument 1 to "fun" has incompatible type "N2"; expected P[int, str] \ + # N: 'N2' is missing following 'P' protocol member: \ + # N: y + +reveal_type(fun3(z)) # E: Revealed type is 'builtins.object*' + +reveal_type(fun3(z3)) # E: Revealed type is 'builtins.int*' +[builtins fixtures/list.pyi] + +[case testBasicCallableStructuralSubtyping] +from typing import Callable, Generic, TypeVar + +def apply(f: Callable[[int], int], x: int) -> int: + return f(x) + +class Add5: + def __call__(self, x: int) -> int: + return x + 5 + +apply(Add5(), 5) + +T = TypeVar('T') +def apply_gen(f: Callable[[T], T]) -> T: + pass + +reveal_type(apply_gen(Add5())) # E: Revealed type is 'builtins.int*' +def apply_str(f: Callable[[str], int], x: str) -> int: + return f(x) +apply_str(Add5(), 'a') # E: Argument 1 to "apply_str" has incompatible type "Add5"; expected Callable[[str], int] \ + # N: 'Add5.__call__' has type 'Callable[[Arg(int, 'x')], int]' +[builtins fixtures/isinstancelist.pyi] + +[case testMoreComplexCallableStructuralSubtyping] +from mypy_extensions import Arg, VarArg +from typing import Protocol, Callable + +def call_soon(cb: Callable[[Arg(int, 'x'), VarArg(str)], int]): pass + +class Good: + def __call__(self, x: int, *rest: str) -> int: pass +class Bad1: + def __call__(self, x: int, *rest: int) -> int: pass +class Bad2: + def __call__(self, y: int, *rest: str) -> int: pass +call_soon(Good()) +call_soon(Bad1()) # E: Argument 1 to "call_soon" has incompatible type "Bad1"; expected Callable[[int, VarArg(str)], int] \ + # N: 'Bad1.__call__' has type 'Callable[[Arg(int, 'x'), VarArg(int)], int]' +call_soon(Bad2()) # E: Argument 1 to "call_soon" has incompatible type "Bad2"; expected Callable[[int, VarArg(str)], int] \ + # N: 'Bad2.__call__' has type 'Callable[[Arg(int, 'y'), VarArg(str)], int]' +[builtins fixtures/isinstancelist.pyi] + +[case testStructuralSupportForPartial] +from typing import Callable, TypeVar, Generic, Any + +T = TypeVar('T') + +class partial(Generic[T]): + def __init__(self, func: Callable[..., T], *args: Any) -> None: ... + def __call__(self, *args: Any) -> T: ... + +def inc(a: int, temp: str) -> int: + pass + +def foo(f: Callable[[int], T]) -> T: + return f(1) + +reveal_type(foo(partial(inc, 'temp'))) # E: Revealed type is 'builtins.int*' +[builtins fixtures/list.pyi] + +[case testStructuralInferenceForCallable] +from typing import Callable, TypeVar, Tuple + +T = TypeVar('T') +S = TypeVar('S') + +class Actual: + def __call__(self, arg: int) -> str: pass + +def fun(cb: Callable[[T], S]) -> Tuple[T, S]: pass +reveal_type(fun(Actual())) # E: Revealed type is 'Tuple[builtins.int*, builtins.str*]' +[builtins fixtures/tuple.pyi] + +-- Standard protocol types (SupportsInt, Sized, etc.) +-- -------------------------------------------------- + +-- More tests could be added for types from typing converted to protocols + +[case testBasicSizedProtocol] +from typing import Sized + +class Foo: + def __len__(self) -> int: + return 42 + +def bar(a: Sized) -> int: + return a.__len__() + +bar(Foo()) +bar((1, 2)) +bar(1) # E: Argument 1 to "bar" has incompatible type "int"; expected "Sized" + +[builtins fixtures/isinstancelist.pyi] + +[case testBasicSupportsIntProtocol] +from typing import SupportsInt + +class Bar: + def __int__(self): + return 1 + +def foo(a: SupportsInt): + pass + +foo(Bar()) +foo('no way') # E: Argument 1 to "foo" has incompatible type "str"; expected "SupportsInt" + +[builtins fixtures/isinstancelist.pyi] + +-- Additional tests and corner cases for protocols +-- ---------------------------------------------- + +[case testAnyWithProtocols] +from typing import Protocol, Any, TypeVar + +T = TypeVar('T') + +class P1(Protocol): + attr1: int +class P2(Protocol[T]): + attr2: T +class P3(Protocol): + attr: P3 + +def f1(x: P1) -> None: pass +def f2(x: P2[str]) -> None: pass +def f3(x: P3) -> None: pass + +class C1: + attr1: Any +class C2: + attr2: Any +class C3: + attr: Any + +f1(C1()) +f2(C2()) +f3(C3()) + +f2(C3()) # E: Argument 1 to "f2" has incompatible type "C3"; expected P2[str] +a: Any +f1(a) +f2(a) +f3(a) + +[case testErrorsForProtocolsInDifferentPlaces] +from typing import Protocol + +class P(Protocol): + attr1: int + attr2: str + attr3: int + +class C: + attr1: str + @property + def attr2(self) -> int: pass + +x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \ + # N: 'C' is missing following 'P' protocol member: \ + # N: attr3 \ + # N: Following member(s) of "C" have conflicts: \ + # N: attr1: expected "int", got "str" \ + # N: attr2: expected "str", got "int" \ + # N: Protocol member P.attr2 expected settable variable, got read-only attribute + +def f(x: P) -> P: + return C() # E: Incompatible return value type (got "C", expected "P") \ + # N: 'C' is missing following 'P' protocol member: \ + # N: attr3 \ + # N: Following member(s) of "C" have conflicts: \ + # N: attr1: expected "int", got "str" \ + # N: attr2: expected "str", got "int" \ + # N: Protocol member P.attr2 expected settable variable, got read-only attribute + +f(C()) # E: Argument 1 to "f" has incompatible type "C"; expected "P" \ + # N: 'C' is missing following 'P' protocol member: \ + # N: attr3 \ + # N: Following member(s) of "C" have conflicts: \ + # N: attr1: expected "int", got "str" \ + # N: attr2: expected "str", got "int" \ + # N: Protocol member P.attr2 expected settable variable, got read-only attribute +[builtins fixtures/list.pyi] + +[case testIterableProtocolOnClass] +from typing import TypeVar, Iterator +T = TypeVar('T', bound='A') + +class A: + def __iter__(self: T) -> Iterator[T]: pass + +class B(A): pass + +reveal_type(list(b for b in B())) # E: Revealed type is 'builtins.list[__main__.B*]' +reveal_type(list(B())) # E: Revealed type is 'builtins.list[__main__.B*]' +[builtins fixtures/list.pyi] + +[case testIterableProtocolOnMetaclass] +from typing import TypeVar, Iterator, Type +T = TypeVar('T') + +class EMeta(type): + def __iter__(self: Type[T]) -> Iterator[T]: pass + +class E(metaclass=EMeta): + pass + +class C(E): + pass + +reveal_type(list(c for c in C)) # E: Revealed type is 'builtins.list[__main__.C*]' +reveal_type(list(C)) # E: Revealed type is 'builtins.list[__main__.C*]' +[builtins fixtures/list.pyi] + +[case testClassesGetattrWithProtocols] +from typing import Protocol + +class P(Protocol): + attr: int + +class PP(Protocol): + @property + def attr(self) -> int: + pass + +class C: + def __getattr__(self, attr: str) -> int: + pass +class C2(C): + def __setattr__(self, attr: str, val: int) -> None: + pass + +class D: + def __getattr__(self, attr: str) -> str: + pass + +def fun(x: P) -> None: + reveal_type(P.attr) # E: Revealed type is 'builtins.int' +def fun_p(x: PP) -> None: + reveal_type(P.attr) # E: Revealed type is 'builtins.int' + +fun(C()) # E: Argument 1 to "fun" has incompatible type "C"; expected "P" \ + # N: Protocol member P.attr expected settable variable, got read-only attribute +fun(C2()) +fun_p(D()) # E: Argument 1 to "fun_p" has incompatible type "D"; expected "PP" \ + # N: Following member(s) of "D" have conflicts: \ + # N: attr: expected "int", got "str" +fun_p(C()) # OK +[builtins fixtures/list.pyi] + +[case testImplicitTypesInProtocols] +from typing import Protocol + +class P(Protocol): + x = 1 # E: All protocol members must have explicitly declared types + +class C: + x: int + +class D: + x: str + +x: P +x = D() # E: Incompatible types in assignment (expression has type "D", variable has type "P") \ + # N: Following member(s) of "D" have conflicts: \ + # N: x: expected "int", got "str" +x = C() # OK +[builtins fixtures/list.pyi] + +[case testProtocolIncompatibilityWithGenericMethod] +from typing import Protocol, TypeVar + +T = TypeVar('T') +S = TypeVar('S') + +class A(Protocol): + def f(self, x: T) -> None: pass +class B: + def f(self, x: S, y: T) -> None: pass + +x: A = B() +[out] +main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A") +main:11: note: Following member(s) of "B" have conflicts: +main:11: note: Expected: +main:11: note: def [T] f(self, x: T) -> None +main:11: note: Got: +main:11: note: def [S, T] f(self, x: S, y: T) -> None + +[case testProtocolIncompatibilityWithGenericMethodBounded] +from typing import Protocol, TypeVar + +T = TypeVar('T') +S = TypeVar('S', bound=int) + +class A(Protocol): + def f(self, x: T) -> None: pass +class B: + def f(self, x: S, y: T) -> None: pass + +x: A = B() +[out] +main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A") +main:11: note: Following member(s) of "B" have conflicts: +main:11: note: Expected: +main:11: note: def [T] f(self, x: T) -> None +main:11: note: Got: +main:11: note: def [S <: int, T] f(self, x: S, y: T) -> None + +[case testProtocolIncompatibilityWithGenericRestricted] +from typing import Protocol, TypeVar + +T = TypeVar('T') +S = TypeVar('S', int, str) + +class A(Protocol): + def f(self, x: T) -> None: pass +class B: + def f(self, x: S, y: T) -> None: pass + +x: A = B() +[out] +main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A") +main:11: note: Following member(s) of "B" have conflicts: +main:11: note: Expected: +main:11: note: def [T] f(self, x: T) -> None +main:11: note: Got: +main:11: note: def [S in (int, str), T] f(self, x: S, y: T) -> None + +[case testProtocolIncompatibilityWithManyOverloads] +from typing import Protocol, overload + +class C1: pass +class C2: pass +class A(Protocol): + @overload + def f(self, x: int) -> int: pass + @overload + def f(self, x: str) -> str: pass + @overload + def f(self, x: C1) -> C2: pass + @overload + def f(self, x: C2) -> C1: pass + +class B: + def f(self) -> None: pass + +x: A = B() +[out] +main:18: error: Incompatible types in assignment (expression has type "B", variable has type "A") +main:18: note: Following member(s) of "B" have conflicts: +main:18: note: Expected: +main:18: note: @overload +main:18: note: def f(self, x: int) -> int +main:18: note: @overload +main:18: note: def f(self, x: str) -> str +main:18: note: <2 more overload(s) not shown> +main:18: note: Got: +main:18: note: def f(self) -> None + +[case testProtocolIncompatibilityWithManyConflicts] +from typing import Protocol + +class A(Protocol): + def f(self, x: int) -> None: pass + def g(self, x: int) -> None: pass + def h(self, x: int) -> None: pass + def i(self, x: int) -> None: pass +class B: + def f(self, x: str) -> None: pass + def g(self, x: str) -> None: pass + def h(self, x: str) -> None: pass + def i(self, x: str) -> None: pass + +x: A = B() +[out] +main:14: error: Incompatible types in assignment (expression has type "B", variable has type "A") +main:14: note: Following member(s) of "B" have conflicts: +main:14: note: Expected: +main:14: note: def f(self, x: int) -> None +main:14: note: Got: +main:14: note: def f(self, x: str) -> None +main:14: note: Expected: +main:14: note: def g(self, x: int) -> None +main:14: note: Got: +main:14: note: def g(self, x: str) -> None +main:14: note: <2 more conflict(s) not shown> + +[case testDontShowNotesForTupleAndIterableProtocol] +from typing import Iterable, Sequence, Protocol, NamedTuple + +class N(NamedTuple): + x: int + +def f1(x: Iterable[str]) -> None: pass +def f2(x: Sequence[str]) -> None: pass + +# The errors below should be short +f1(N(1)) # E: Argument 1 to "f1" has incompatible type "N"; expected Iterable[str] +f2(N(2)) # E: Argument 1 to "f2" has incompatible type "N"; expected Sequence[str] +[builtins fixtures/tuple.pyi] + +[case testNotManyFlagConflitsShownInProtocols] +from typing import Protocol + +class AllSettable(Protocol): + a: int + b: int + c: int + d: int + +class AllReadOnly: + @property + def a(self) -> int: pass + @property + def b(self) -> int: pass + @property + def c(self) -> int: pass + @property + def d(self) -> int: pass + +x: AllSettable = AllReadOnly() +[builtins fixtures/property.pyi] +[out] +main:19: error: Incompatible types in assignment (expression has type "AllReadOnly", variable has type "AllSettable") +main:19: note: Protocol member AllSettable.a expected settable variable, got read-only attribute +main:19: note: Protocol member AllSettable.b expected settable variable, got read-only attribute +main:19: note: <2 more conflict(s) not shown> + +[case testProtocolsMoreConflictsNotShown] +from typing_extensions import Protocol +from typing import Generic, TypeVar + +T = TypeVar('T') + +class MockMapping(Protocol[T]): + def a(self, x: T) -> int: pass + def b(self, x: T) -> int: pass + def c(self, x: T) -> int: pass + d: T + e: T + f: T + +class MockDict(MockMapping[T]): + more: int + +def f(x: MockMapping[int]) -> None: pass +x: MockDict[str] +f(x) # E: Argument 1 to "f" has incompatible type MockDict[str]; expected MockMapping[int] + +[case testProtocolNotesForComplexSignatures] +from typing import Protocol, Optional + +class P(Protocol): + def meth(self, x: int, *args: str) -> None: pass + def other(self, *args, hint: Optional[str] = None, **kwargs: str) -> None: pass +class C: + def meth(self) -> int: pass + def other(self) -> int: pass + +x: P = C() +[builtins fixtures/dict.pyi] +[out] +main:10: error: Incompatible types in assignment (expression has type "C", variable has type "P") +main:10: note: Following member(s) of "C" have conflicts: +main:10: note: Expected: +main:10: note: def meth(self, x: int, *args: str) -> None +main:10: note: Got: +main:10: note: def meth(self) -> int +main:10: note: Expected: +main:10: note: def other(self, *args: Any, hint: Optional[str] = ..., **kwargs: str) -> None +main:10: note: Got: +main:10: note: def other(self) -> int + diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 193c7468d442..a13d1e704f41 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -343,7 +343,9 @@ Point = TypedDict('Point', {'x': int, 'y': int}) def as_dict(p: Point) -> Dict[str, int]: return p # E: Incompatible return value type (got "Point", expected Dict[str, int]) def as_mutable_mapping(p: Point) -> MutableMapping[str, int]: - return p # E: Incompatible return value type (got "Point", expected MutableMapping[str, int]) + return p # E: Incompatible return value type (got "Point", expected MutableMapping[str, int]) \ + # N: 'Point' is missing following 'MutableMapping' protocol member: \ + # N: __setitem__ [builtins fixtures/dict.pyi] [case testCanConvertTypedDictToAny] @@ -372,6 +374,52 @@ ll = [b, c] f(ll) # E: Argument 1 to "f" has incompatible type List[TypedDict({'x': int, 'z': str})]; expected "A" [builtins fixtures/dict.pyi] +[case testTypedDictWithSimpleProtocol] +from typing_extensions import Protocol +from mypy_extensions import TypedDict + +class StrIntMap(Protocol): + def __getitem__(self, key: str) -> int: ... + +A = TypedDict('A', {'x': int, 'y': int}) +B = TypedDict('B', {'x': int, 'y': str}) + +def fun(arg: StrIntMap) -> None: ... +a: A +b: B +fun(a) +fun(b) # Error +[builtins fixtures/dict.pyi] +[out] +main:14: error: Argument 1 to "fun" has incompatible type "B"; expected "StrIntMap" +main:14: note: Following member(s) of "B" have conflicts: +main:14: note: Expected: +main:14: note: def __getitem__(self, str) -> int +main:14: note: Got: +main:14: note: def __getitem__(self, str) -> object + +[case testTypedDictWithSimpleProtocolInference] +from typing_extensions import Protocol +from mypy_extensions import TypedDict +from typing import TypeVar + +T_co = TypeVar('T_co', covariant=True) +T = TypeVar('T') + +class StrMap(Protocol[T_co]): + def __getitem__(self, key: str) -> T_co: ... + +A = TypedDict('A', {'x': int, 'y': int}) +B = TypedDict('B', {'x': int, 'y': str}) + +def fun(arg: StrMap[T]) -> T: + return arg['whatever'] +a: A +b: B +reveal_type(fun(a)) # E: Revealed type is 'builtins.int*' +reveal_type(fun(b)) # E: Revealed type is 'builtins.object*' +[builtins fixtures/dict.pyi] +[out] -- Join @@ -1132,9 +1180,16 @@ def f(x: int) -> None: ... def f(x): pass a: A -f(a) # E: Argument 1 to "f" has incompatible type "A"; expected Iterable[int] +f(a) [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] +[out] +main:13: error: Argument 1 to "f" has incompatible type "A"; expected Iterable[int] +main:13: note: Following member(s) of "A" have conflicts: +main:13: note: Expected: +main:13: note: def __iter__(self) -> Iterator[int] +main:13: note: Got: +main:13: note: def __iter__(self) -> Iterator[str] [case testTypedDictOverloading3] from typing import overload diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index a271315e4dde..cf8b61f9397a 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -11,11 +11,12 @@ class object: class type: pass -class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]): +class dict(Generic[KT, VT]): @overload def __init__(self, **kwargs: VT) -> None: pass @overload def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass + def __getitem__(self, key: KT) -> VT: pass def __setitem__(self, k: KT, v: VT) -> None: pass def __iter__(self) -> Iterator[KT]: pass def update(self, a: Mapping[KT, VT]) -> None: pass @@ -23,6 +24,7 @@ class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]): def get(self, k: KT) -> Optional[VT]: pass @overload def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass + def __len__(self) -> int: ... class int: # for convenience def __add__(self, x: int) -> int: pass @@ -30,7 +32,7 @@ class int: # for convenience class str: pass # for keyword argument key type class unicode: pass # needed for py2 docstrings -class list(Iterable[T], Generic[T]): # needed by some test cases +class list(Generic[T]): # needed by some test cases def __getitem__(self, x: int) -> T: pass def __iter__(self) -> Iterator[T]: pass def __mul__(self, x: int) -> list[T]: pass @@ -41,4 +43,5 @@ class float: pass class bool: pass class ellipsis: pass +def isinstance(x: object, t: Union[type, Tuple]) -> bool: pass class BaseException: pass diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 930d309aa3db..99aca1befe39 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -23,16 +23,17 @@ T = TypeVar('T') KT = TypeVar('KT') VT = TypeVar('VT') -class tuple(Generic[T]): pass +class tuple(Generic[T]): + def __len__(self) -> int: pass -class list(Iterable[T]): +class list(Generic[T]): def __iter__(self) -> Iterator[T]: pass def __mul__(self, x: int) -> list[T]: pass def __setitem__(self, x: int, v: T) -> None: pass def __getitem__(self, x: int) -> T: pass def __add__(self, x: List[T]) -> T: pass -class dict(Iterable[KT], Mapping[KT, VT]): +class dict(Mapping[KT, VT]): @overload def __init__(self, **kwargs: VT) -> None: pass @overload @@ -41,7 +42,7 @@ class dict(Iterable[KT], Mapping[KT, VT]): def __iter__(self) -> Iterator[KT]: pass def update(self, a: Mapping[KT, VT]) -> None: pass -class set(Iterable[T]): +class set(Generic[T]): def __iter__(self) -> Iterator[T]: pass def add(self, x: T) -> None: pass def discard(self, x: T) -> None: pass diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index d5d1000d3364..7b6d1dbd127b 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -10,7 +10,7 @@ class object: class type: pass class ellipsis: pass -class list(Iterable[T], Generic[T]): +class list(Generic[T]): @overload def __init__(self) -> None: pass @overload diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index d43e34065907..62fac70034c0 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -17,6 +17,7 @@ Union = 0 Optional = 0 TypeVar = 0 Generic = 0 +Protocol = 0 Tuple = 0 Callable = 0 _promote = 0 @@ -33,33 +34,42 @@ Dict = 0 Set = 0 T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) +T_contra = TypeVar('T_contra', contravariant=True) U = TypeVar('U') V = TypeVar('V') S = TypeVar('S') -class Container(Generic[T]): +# Note: definitions below are different from typeshed, variances are declared +# to silence the protocol variance checks. Maybe it is better to use type: ignore? + +@runtime +class Container(Protocol[T_contra]): @abstractmethod # Use int because bool isn't in the default test builtins - def __contains__(self, arg: T) -> int: pass + def __contains__(self, arg: T_contra) -> int: pass -class Sized: +@runtime +class Sized(Protocol): @abstractmethod def __len__(self) -> int: pass -class Iterable(Generic[T]): +@runtime +class Iterable(Protocol[T_co]): @abstractmethod - def __iter__(self) -> 'Iterator[T]': pass + def __iter__(self) -> 'Iterator[T_co]': pass -class Iterator(Iterable[T], Generic[T]): +@runtime +class Iterator(Iterable[T_co], Protocol): @abstractmethod - def __next__(self) -> T: pass + def __next__(self) -> T_co: pass class Generator(Iterator[T], Generic[T, U, V]): @abstractmethod def send(self, value: U) -> T: pass @abstractmethod - def throw(self, typ: Any, val: Any=None, tb=None) -> None: pass + def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass @abstractmethod def close(self) -> None: pass @@ -83,38 +93,52 @@ class AsyncGenerator(AsyncIterator[T], Generic[T, U]): @abstractmethod def __aiter__(self) -> 'AsyncGenerator[T, U]': pass -class Awaitable(Generic[T]): +@runtime +class Awaitable(Protocol[T]): @abstractmethod def __await__(self) -> Generator[Any, Any, T]: pass class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S]): pass -class AsyncIterable(Generic[T]): +@runtime +class AsyncIterable(Protocol[T]): @abstractmethod def __aiter__(self) -> 'AsyncIterator[T]': pass -class AsyncIterator(AsyncIterable[T], Generic[T]): +@runtime +class AsyncIterator(AsyncIterable[T], Protocol): def __aiter__(self) -> 'AsyncIterator[T]': return self @abstractmethod def __anext__(self) -> Awaitable[T]: pass -class Sequence(Iterable[T], Generic[T]): +@runtime +class Sequence(Iterable[T_co], Protocol): @abstractmethod - def __getitem__(self, n: Any) -> T: pass + def __getitem__(self, n: Any) -> T_co: pass -class Mapping(Iterable[T], Sized, Generic[T, U]): +@runtime +class Mapping(Iterable[T], Protocol[T, T_co]): + def __getitem__(self, key: T) -> T_co: pass @overload - def get(self, k: T) -> Optional[U]: ... + def get(self, k: T) -> Optional[T_co]: pass @overload - def get(self, k: T, default: Union[U, V]) -> Union[U, V]: ... - def values(self) -> Iterable[U]: pass # Approximate return type + def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass + def values(self) -> Iterable[T_co]: pass # Approximate return type def __len__(self) -> int: ... -class MutableMapping(Mapping[T, U]): pass +@runtime +class MutableMapping(Mapping[T, U], Protocol): + def __setitem__(self, k: T, v: U) -> None: pass + +class SupportsInt(Protocol): + def __int__(self) -> int: pass + +def runtime(cls: T) -> T: + return cls class ContextManager(Generic[T]): - def __enter__(self) -> T: ... - def __exit__(self, exc_type, exc_value, traceback): ... + def __enter__(self) -> T: pass + def __exit__(self, exc_type, exc_value, traceback): pass TYPE_CHECKING = 1 diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi index 457bea0e9020..87b50f532376 100644 --- a/test-data/unit/lib-stub/builtins.pyi +++ b/test-data/unit/lib-stub/builtins.pyi @@ -15,7 +15,6 @@ class bytes: pass class tuple: pass class function: pass - class ellipsis: pass # Definition of None is implicit diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index 02412c77a7b8..8be4abca389a 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -12,6 +12,7 @@ Union = 0 Optional = 0 TypeVar = 0 Generic = 0 +Protocol = 0 # This is not yet defined in typeshed, see PR typeshed/#1220 Tuple = 0 Callable = 0 _promote = 0 @@ -28,37 +29,57 @@ Dict = 0 Set = 0 T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) +T_contra = TypeVar('T_contra', contravariant=True) U = TypeVar('U') V = TypeVar('V') S = TypeVar('S') -class Container(Generic[T]): +# Note: definitions below are different from typeshed, variances are declared +# to silence the protocol variance checks. Maybe it is better to use type: ignore? + +@runtime +class Container(Protocol[T_contra]): @abstractmethod # Use int because bool isn't in the default test builtins - def __contains__(self, arg: T) -> int: pass + def __contains__(self, arg: T_contra) -> int: pass -class Sized: +@runtime +class Sized(Protocol): @abstractmethod def __len__(self) -> int: pass -class Iterable(Generic[T]): +@runtime +class Iterable(Protocol[T_co]): @abstractmethod - def __iter__(self) -> 'Iterator[T]': pass + def __iter__(self) -> 'Iterator[T_co]': pass -class Iterator(Iterable[T], Generic[T]): +@runtime +class Iterator(Iterable[T_co], Protocol): @abstractmethod - def __next__(self) -> T: pass + def __next__(self) -> T_co: pass class Generator(Iterator[T], Generic[T, U, V]): @abstractmethod def __iter__(self) -> 'Generator[T, U, V]': pass -class Sequence(Iterable[T], Generic[T]): +@runtime +class Sequence(Iterable[T_co], Protocol): @abstractmethod - def __getitem__(self, n: Any) -> T: pass + def __getitem__(self, n: Any) -> T_co: pass + +@runtime +class Mapping(Protocol[T_contra, T_co]): + def __getitem__(self, key: T_contra) -> T_co: pass + +@runtime +class MutableMapping(Mapping[T_contra, U], Protocol): + def __setitem__(self, k: T_contra, v: U) -> None: pass -class Mapping(Generic[T, U]): pass +class SupportsInt(Protocol): + def __int__(self) -> int: pass -class MutableMapping(Generic[T, U]): pass +def runtime(cls: T) -> T: + return cls TYPE_CHECKING = 1 diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi new file mode 100644 index 000000000000..8c5be8f3637f --- /dev/null +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -0,0 +1,6 @@ +from typing import TypeVar + +_T = TypeVar('_T') + +class Protocol: pass +def runtime(x: _T) -> _T: pass diff --git a/test-data/unit/semanal-errors.test b/test-data/unit/semanal-errors.test index 2192bce3221d..0dba8f2836bf 100644 --- a/test-data/unit/semanal-errors.test +++ b/test-data/unit/semanal-errors.test @@ -953,7 +953,7 @@ from typing import Generic, TypeVar T = TypeVar('T') S = TypeVar('S') class A(Generic[T], Generic[S]): pass \ - # E: Duplicate Generic in bases + # E: Only single Generic[...] or Protocol[...] can be in bases [out] [case testInvalidMetaclass]