diff --git a/docs/source/class_basics.rst b/docs/source/class_basics.rst
index dc778d39424f..e03c299e977c 100644
--- a/docs/source/class_basics.rst
+++ b/docs/source/class_basics.rst
@@ -151,7 +151,174 @@ concrete. As with normal overrides, a dynamically typed method can
implement a statically typed abstract method defined in an abstract
base class.
+.. _protocol-types:
+
+Protocols and structural subtyping
+**********************************
+
+.. note::
+
+ The support for structural subtyping is still experimental. Some features
+ might be not yet implemented, mypy could pass unsafe code or reject
+ working code.
+
+There are two main type systems with respect to subtyping: nominal subtyping
+and structural subtyping. The *nominal* subtyping is based on class hierarchy,
+so that if class ``D`` inherits from class ``C``, then it is a subtype
+of ``C``. This type system is primarily used in mypy since it allows
+to produce clear and concise error messages, and since Python provides native
+``isinstance()`` checks based on class hierarchy. The *structural* subtyping
+however has its own advantages. In this system class ``D`` is a subtype
+of class ``C`` if the former has all attributes of the latter with
+compatible types.
+
+This type system is a static equivalent of duck typing, well known by Python
+programmers. Mypy provides an opt-in support for structural subtyping via
+protocol classes described in this section.
+See `PEP 544 `_ for
+specification of protocols and structural subtyping in Python.
+
+User defined protocols
+**********************
+
+To define a protocol class, one must inherit the special
+``typing_extensions.Protocol`` class:
+
+.. code-block:: python
+
+ from typing import Iterable
+ from typing_extensions import Protocol
+
+ class SupportsClose(Protocol):
+ def close(self) -> None:
+ ...
+
+ class Resource: # Note, this class does not have 'SupportsClose' base.
+ # some methods
+ def close(self) -> None:
+ self.resource.release()
+
+ def close_all(things: Iterable[SupportsClose]) -> None:
+ for thing in things:
+ thing.close()
+
+ close_all([Resource(), open('some/file')]) # This passes type check
+
+.. note::
+
+ The ``Protocol`` base class is currently provided in ``typing_extensions``
+ package. When structural subtyping is mature and
+ `PEP 544 `_ is accepted,
+ ``Protocol`` will be included in the ``typing`` module. As well, several
+ types such as ``typing.Sized``, ``typing.Iterable`` etc. will be made
+ protocols.
+
+Defining subprotocols
+*********************
+
+Subprotocols are also supported. Existing protocols can be extended
+and merged using multiple inheritance. For example:
+
+.. code-block:: python
+
+ # continuing from previous example
+
+ class SupportsRead(Protocol):
+ def read(self, amount: int) -> bytes: ...
+
+ class TaggedReadableResource(SupportsClose, SupportsRead, Protocol):
+ label: str
+
+ class AdvancedResource(Resource):
+ def __init__(self, label: str) -> None:
+ self.label = label
+ def read(self, amount: int) -> bytes:
+ # some implementation
+ ...
+
+ resource = None # type: TaggedReadableResource
+
+ # some code
+
+ resource = AdvancedResource('handle with care') # OK
+
+Note that inheriting from existing protocols does not automatically turn
+a subclass into a protocol, it just creates a usual (non-protocol) ABC that
+implements given protocols. The ``typing_extensions.Protocol`` base must always
+be explicitly present:
+
+.. code-block:: python
+
+ class NewProtocol(SupportsClose): # This is NOT a protocol
+ new_attr: int
+
+ class Concrete:
+ new_attr = None # type: int
+ def close(self) -> None:
+ ...
+ # Below is an error, since nominal subtyping is used by default
+ x = Concrete() # type: NewProtocol # Error!
+
+.. note::
+
+ The `PEP 526 `_ variable
+ annotations can be used to declare protocol attributes. However, protocols
+ are also supported on Python 2.7 and Python 3.3+ with the help of type
+ comments and properties, see
+ `backwards compatibility in PEP 544 `_.
+
+Recursive protocols
+*******************
+
+Protocols can be recursive and mutually recursive. This could be useful for
+declaring abstract recursive collections such as trees and linked lists:
+
+.. code-block:: python
+
+ from typing import TypeVar, Optional
+ from typing_extensions import Protocol
+
+ class TreeLike(Protocol):
+ value: int
+ @property
+ def left(self) -> Optional['TreeLike']: ...
+ @property
+ def right(self) -> Optional['TreeLike']: ...
+
+ class SimpleTree:
+ def __init__(self, value: int) -> None:
+ self.value = value
+ self.left: Optional['SimpleTree'] = None
+ self.right: Optional['SimpleTree'] = None
+
+ root = SimpleTree(0) # type: TreeLike # OK
+
+Using ``isinstance()`` with protocols
+*************************************
+
+To use a protocol class with ``isinstance()``, one needs to decorate it with
+a special ``typing_extensions.runtime`` decorator. It will add support for
+basic runtime structural checks:
+
+.. code-block:: python
+
+ from typing_extensions import Protocol, runtime
+
+ @runtime
+ class Portable(Protocol):
+ handles: int
+
+ class Mug:
+ def __init__(self) -> None:
+ self.handles = 1
+
+ mug = Mug()
+ if isinstance(mug, Portable):
+ use(mug.handles) # Works statically and at runtime.
+
.. note::
+ ``isinstance()`` is with protocols not completely safe at runtime.
+ For example, signatures of methods are not checked. The runtime
+ implementation only checks the presence of all protocol members
+ in object's MRO.
- There are also plans to support more Python-style "duck typing" in
- the type system. The details are still open.
diff --git a/docs/source/common_issues.rst b/docs/source/common_issues.rst
index 0c8b500d8f06..e31d9a7bc4e6 100644
--- a/docs/source/common_issues.rst
+++ b/docs/source/common_issues.rst
@@ -226,6 +226,48 @@ Possible strategies in such situations are:
return x[0]
f_good(new_lst) # OK
+Covariant subtyping of mutable protocol members is rejected
+-----------------------------------------------------------
+
+Mypy rejects this because this is potentially unsafe.
+Consider this example:
+
+.. code-block:: python
+
+ from typing_extensions import Protocol
+
+ class P(Protocol):
+ x: float
+
+ def fun(arg: P) -> None:
+ arg.x = 3.14
+
+ class C:
+ x = 42
+ c = C()
+ fun(c) # This is not safe
+ c.x << 5 # Since this will fail!
+
+To work around this problem consider whether "mutating" is actually part
+of a protocol. If not, then one can use a ``@property`` in
+the protocol definition:
+
+.. code-block:: python
+
+ from typing_extensions import Protocol
+
+ class P(Protocol):
+ @property
+ def x(self) -> float:
+ pass
+
+ def fun(arg: P) -> None:
+ ...
+
+ class C:
+ x = 42
+ fun(C()) # OK
+
Declaring a supertype as variable type
--------------------------------------
diff --git a/docs/source/faq.rst b/docs/source/faq.rst
index 9fd73b42b8a3..b131e3f7e2a2 100644
--- a/docs/source/faq.rst
+++ b/docs/source/faq.rst
@@ -101,35 +101,38 @@ Is mypy free?
Yes. Mypy is free software, and it can also be used for commercial and
proprietary projects. Mypy is available under the MIT license.
-Why not use structural subtyping?
-*********************************
+Can I use structural subtyping?
+*******************************
-Mypy primarily uses `nominal subtyping
-`_ instead of
+Mypy provides support for both `nominal subtyping
+`_ and
`structural subtyping
-`_. Some argue
-that structural subtyping is better suited for languages with duck
-typing such as Python.
-
-Here are some reasons why mypy uses nominal subtyping:
+`_.
+Support for structural subtyping is considered experimental.
+Some argue that structural subtyping is better suited for languages with duck
+typing such as Python. Mypy however primarily uses nominal subtyping,
+leaving structural subtyping opt-in. Here are some reasons why:
1. It is easy to generate short and informative error messages when
using a nominal type system. This is especially important when
using type inference.
-2. Python supports basically nominal isinstance tests and they are
- widely used in programs. It is not clear how to support isinstance
- in a purely structural type system while remaining compatible with
- Python idioms.
+2. Python provides built-in support for nominal ``isinstance()`` tests and
+ they are widely used in programs. Only limited support for structural
+ ``isinstance()`` exists for ABCs in ``collections.abc`` and ``typing``
+ standard library modules.
3. Many programmers are already familiar with nominal subtyping and it
has been successfully used in languages such as Java, C++ and
C#. Only few languages use structural subtyping.
-However, structural subtyping can also be useful. Structural subtyping
-is a likely feature to be added to mypy in the future, even though we
-expect that most mypy programs will still primarily use nominal
-subtyping.
+However, structural subtyping can also be useful. For example, a "public API"
+may be more flexible if it is typed with protocols. Also, using protocol types
+removes the necessity to explicitly declare implementations of ABCs.
+As a rule of thumb, we recommend using nominal classes where possible, and
+protocols where necessary. For more details about protocol types and structural
+subtyping see :ref:`protocol-types` and
+`PEP 544 `_.
I like Python and I have no need for static typing
**************************************************
diff --git a/docs/source/generics.rst b/docs/source/generics.rst
index bd0e0549fd1a..8f8f5355b93e 100644
--- a/docs/source/generics.rst
+++ b/docs/source/generics.rst
@@ -1,6 +1,8 @@
Generics
========
+.. _generic-classes:
+
Defining generic classes
************************
@@ -489,6 +491,73 @@ restrict the valid values for the type parameter in the same way.
A type variable may not have both a value restriction (see
:ref:`type-variable-value-restriction`) and an upper bound.
+Generic protocols
+*****************
+
+Generic protocols (see :ref:`protocol-types`) are also supported, generic
+protocols mostly follow the normal rules for generic classes, the main
+difference is that mypy checks that declared variance of type variables is
+compatible with the class definition. Examples:
+
+.. code-block:: python
+
+ from typing import TypeVar
+ from typing_extensions import Protocol
+
+ T = TypeVar('T')
+
+ class Box(Protocol[T]):
+ content: T
+
+ def do_stuff(one: Box[str], other: Box[bytes]) -> None:
+ ...
+
+ class StringWrapper:
+ def __init__(self, content: str) -> None:
+ self.content = content
+
+ class BytesWrapper:
+ def __init__(self, content: bytes) -> None:
+ self.content = content
+
+ do_stuff(StringWrapper('one'), BytesWrapper(b'other')) # OK
+
+ x = None # type: Box[float]
+ y = None # type: Box[int]
+ x = y # Error, since the protocol 'Box' is invariant.
+
+ class AnotherBox(Protocol[T]): # Error, covariant type variable expected
+ def content(self) -> T:
+ ...
+
+ T_co = TypeVar('T_co', covariant=True)
+ class AnotherBox(Protocol[T_co]): # OK
+ def content(self) -> T_co:
+ ...
+
+ ax = None # type: AnotherBox[float]
+ ay = None # type: AnotherBox[int]
+ ax = ay # OK for covariant protocols
+
+See :ref:`variance-of-generics` above for more details on variance.
+Generic protocols can be recursive, for example:
+
+.. code-block:: python
+
+ T = TypeVar('T')
+ class Linked(Protocol[T]):
+ val: T
+ def next(self) -> 'Linked[T]': ...
+
+ class L:
+ val: int
+ def next(self) -> 'L': ...
+
+ def last(seq: Linked[T]) -> T:
+ ...
+
+ result = last(L()) # The inferred type of 'result' is 'int'
+
.. _declaring-decorators:
Declaring decorators
diff --git a/mypy/checker.py b/mypy/checker.py
index ca9cf379df06..16329871b0bc 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -27,7 +27,7 @@
RefExpr, YieldExpr, BackquoteExpr, Import, ImportFrom, ImportAll, ImportBase,
AwaitExpr, PromoteExpr, Node, EnumCallExpr,
ARG_POS, MDEF,
- CONTRAVARIANT, COVARIANT)
+ CONTRAVARIANT, COVARIANT, INVARIANT)
from mypy import nodes
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
from mypy.types import (
@@ -44,7 +44,7 @@
from mypy.subtypes import (
is_subtype, is_equivalent, is_proper_subtype, is_more_precise,
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype,
- unify_generic_callable,
+ unify_generic_callable, find_member
)
from mypy.maptype import map_instance_to_supertype
from mypy.typevars import fill_typevars, has_no_typevars
@@ -1147,6 +1147,8 @@ def erase_override(t: Type) -> Type:
def visit_class_def(self, defn: ClassDef) -> None:
"""Type check a class definition."""
typ = defn.info
+ if typ.is_protocol and typ.defn.type_vars:
+ self.check_protocol_variance(defn)
with self.errors.enter_type(defn.name), self.enter_partial_types():
old_binder = self.binder
self.binder = ConditionalTypeBinder()
@@ -1158,6 +1160,33 @@ def visit_class_def(self, defn: ClassDef) -> None:
# Otherwise we've already found errors; more errors are not useful
self.check_multiple_inheritance(typ)
+ def check_protocol_variance(self, defn: ClassDef) -> None:
+ """Check that protocol definition is compatible with declared
+ variances of type variables.
+
+ Note that we also prohibit declaring protocol classes as invariant
+ if they are actually covariant/contravariant, since this may break
+ transitivity of subtyping, see PEP 544.
+ """
+ info = defn.info
+ object_type = Instance(info.mro[-1], [])
+ tvars = info.defn.type_vars
+ for i, tvar in enumerate(tvars):
+ up_args = [object_type if i == j else AnyType(TypeOfAny.special_form)
+ for j, _ in enumerate(tvars)]
+ down_args = [UninhabitedType() if i == j else AnyType(TypeOfAny.special_form)
+ for j, _ in enumerate(tvars)]
+ up, down = Instance(info, up_args), Instance(info, down_args)
+ # TODO: add advanced variance checks for recursive protocols
+ if is_subtype(down, up, ignore_declared_variance=True):
+ expected = COVARIANT
+ elif is_subtype(up, down, ignore_declared_variance=True):
+ expected = CONTRAVARIANT
+ else:
+ expected = INVARIANT
+ if expected != tvar.variance:
+ self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)
+
def check_multiple_inheritance(self, typ: TypeInfo) -> None:
"""Check for multiple inheritance related errors."""
if len(typ.bases) <= 1:
@@ -1328,15 +1357,16 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
else:
rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, lvalue)
- # Special case: only non-abstract classes can be assigned to variables
- # with explicit type Type[A].
+ # Special case: only non-abstract non-protocol classes can be assigned to
+ # variables with explicit type Type[A], where A is protocol or abstract.
if (isinstance(rvalue_type, CallableType) and rvalue_type.is_type_obj() and
- rvalue_type.type_object().is_abstract and
+ (rvalue_type.type_object().is_abstract or
+ rvalue_type.type_object().is_protocol) and
isinstance(lvalue_type, TypeType) and
isinstance(lvalue_type.item, Instance) and
- lvalue_type.item.type.is_abstract):
- self.fail("Can only assign non-abstract classes"
- " to a variable of type '{}'".format(lvalue_type), rvalue)
+ (lvalue_type.item.type.is_abstract or
+ lvalue_type.item.type.is_protocol)):
+ self.msg.concrete_only_assign(lvalue_type, rvalue)
return
if rvalue_type and infer_lvalue_type:
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
@@ -2468,6 +2498,13 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context,
self.fail(msg, context)
if note_msg:
self.note(note_msg, context)
+ if (isinstance(supertype, Instance) and supertype.type.is_protocol and
+ isinstance(subtype, (Instance, TupleType, TypedDictType))):
+ self.msg.report_protocol_problems(subtype, supertype, context)
+ if isinstance(supertype, CallableType) and isinstance(subtype, Instance):
+ call = find_member('__call__', subtype, subtype)
+ if call:
+ self.msg.note_call(subtype, call, context)
return False
def contains_none(self, t: Type) -> bool:
@@ -2609,9 +2646,9 @@ def warn(self, msg: str, context: Context) -> None:
"""Produce a warning message."""
self.msg.warn(msg, context)
- def note(self, msg: str, context: Context) -> None:
+ def note(self, msg: str, context: Context, offset: int = 0) -> None:
"""Produce a note."""
- self.msg.note(msg, context)
+ self.msg.note(msg, context, offset=offset)
def iterable_item_type(self, instance: Instance) -> Type:
iterable = map_instance_to_supertype(
diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py
index a26637dfd544..1d50aa749274 100644
--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -35,7 +35,7 @@
from mypy import join
from mypy.meet import narrow_declared_type
from mypy.maptype import map_instance_to_supertype
-from mypy.subtypes import is_subtype, is_equivalent
+from mypy.subtypes import is_subtype, is_equivalent, find_member
from mypy import applytype
from mypy import erasetype
from mypy.checkmember import analyze_member_access, type_object_type, bind_self
@@ -245,6 +245,15 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
callee_type = self.apply_method_signature_hook(
e, callee_type, object_type, signature_hook)
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type)
+ if (isinstance(e.callee, RefExpr) and len(e.args) == 2 and
+ e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass')):
+ for expr in mypy.checker.flatten(e.args[1]):
+ tp = self.chk.type_map[expr]
+ if (isinstance(tp, CallableType) and tp.is_type_obj() and
+ tp.type_object().is_protocol and
+ not tp.type_object().runtime_protocol):
+ self.chk.fail('Only @runtime protocols can be used with'
+ ' instance and class checks', e)
if isinstance(ret_type, UninhabitedType):
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
@@ -502,6 +511,11 @@ def check_call(self, callee: Type, args: List[Expression],
self.msg.cannot_instantiate_abstract_class(
callee.type_object().name(), type.abstract_attributes,
context)
+ elif (callee.is_type_obj() and callee.type_object().is_protocol
+ # Exceptions for Type[...] and classmethod first argument
+ and not callee.from_type_type and not callee.is_classmethod_class):
+ self.chk.fail('Cannot instantiate protocol class "{}"'
+ .format(callee.type_object().name()), context)
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
@@ -1001,19 +1015,28 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
"""Check the type of a single argument in a call."""
if isinstance(caller_type, DeletedType):
messages.deleted_as_rvalue(caller_type, context)
- # Only non-abstract class can be given where Type[...] is expected...
+ # Only non-abstract non-protocol class can be given where Type[...] is expected...
elif (isinstance(caller_type, CallableType) and isinstance(callee_type, TypeType) and
- caller_type.is_type_obj() and caller_type.type_object().is_abstract and
- isinstance(callee_type.item, Instance) and callee_type.item.type.is_abstract and
+ caller_type.is_type_obj() and
+ (caller_type.type_object().is_abstract or caller_type.type_object().is_protocol) and
+ isinstance(callee_type.item, Instance) and
+ (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
# ...except for classmethod first argument
not caller_type.is_classmethod_class):
- messages.fail("Only non-abstract class can be given where '{}' is expected"
- .format(callee_type), context)
+ self.msg.concrete_only_call(callee_type, context)
elif not is_subtype(caller_type, callee_type):
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
return
messages.incompatible_argument(n, m, callee, original_caller_type,
caller_kind, context)
+ if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
+ isinstance(callee_type, Instance) and callee_type.type.is_protocol):
+ self.msg.report_protocol_problems(original_caller_type, callee_type, context)
+ if (isinstance(callee_type, CallableType) and
+ isinstance(original_caller_type, Instance)):
+ call = find_member('__call__', original_caller_type, original_caller_type)
+ if call:
+ self.msg.note_call(original_caller_type, call, context)
def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: List[str],
@@ -2697,6 +2720,8 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
# subtyping algorithm if type promotions are possible (e.g., int vs. float).
if formal.type in actual.type.mro:
return 2
+ elif formal.type.is_protocol and is_subtype(actual, erasetype.erase_type(formal)):
+ return 2
elif actual.type._promote and is_subtype(actual, formal):
return 1
else:
diff --git a/mypy/checkmember.py b/mypy/checkmember.py
index 93e210523c46..ceb9a4ff00ce 100644
--- a/mypy/checkmember.py
+++ b/mypy/checkmember.py
@@ -53,6 +53,8 @@ def analyze_member_access(name: str,
the fallback type, for example.
original_type is always the type used in the initial call.
"""
+ # TODO: this and following functions share some logic with subtypes.find_member,
+ # consider refactoring.
if isinstance(typ, Instance):
if name == '__init__' and not is_super:
# Accessing __init__ in statically typed code would compromise
diff --git a/mypy/constraints.py b/mypy/constraints.py
index 0aa12322bc70..7180dd00f25f 100644
--- a/mypy/constraints.py
+++ b/mypy/constraints.py
@@ -305,7 +305,7 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
# Non-leaf types
def visit_instance(self, template: Instance) -> List[Constraint]:
- actual = self.actual
+ original_actual = actual = self.actual
res = [] # type: List[Constraint]
if isinstance(actual, CallableType) and actual.fallback is not None:
actual = actual.fallback
@@ -313,6 +313,8 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
actual = actual.as_anonymous().fallback
if isinstance(actual, Instance):
instance = actual
+ # We always try nominal inference if possible,
+ # it is much faster than the structural one.
if (self.direction == SUBTYPE_OF and
template.type.has_base(instance.type.fullname())):
mapped = map_instance_to_supertype(template, instance.type)
@@ -336,6 +338,28 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
res.extend(infer_constraints(
template.args[j], mapped.args[j], neg_op(self.direction)))
return res
+ if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
+ # We avoid infinite recursion for structural subtypes by checking
+ # whether this type already appeared in the inference chain.
+ # This is a conservative way break the inference cycles.
+ # It never produces any "false" constraints but gives up soon
+ # on purely structural inference cycles, see #3829.
+ not any(is_same_type(template, t) for t in template.type.inferring) and
+ mypy.subtypes.is_subtype(instance, erase_typevars(template))):
+ template.type.inferring.append(template)
+ self.infer_constraints_from_protocol_members(res, instance, template,
+ original_actual, template)
+ template.type.inferring.pop()
+ return res
+ elif (instance.type.is_protocol and self.direction == SUBTYPE_OF and
+ # We avoid infinite recursion for structural subtypes also here.
+ not any(is_same_type(instance, i) for i in instance.type.inferring) and
+ mypy.subtypes.is_subtype(erase_typevars(template), instance)):
+ instance.type.inferring.append(instance)
+ self.infer_constraints_from_protocol_members(res, instance, template,
+ template, instance)
+ instance.type.inferring.pop()
+ return res
if isinstance(actual, AnyType):
# IDEA: Include both ways, i.e. add negation as well?
return self.infer_against_any(template.args, actual)
@@ -349,9 +373,36 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
cb = infer_constraints(template.args[0], item, SUPERTYPE_OF)
res.extend(cb)
return res
+ elif (isinstance(actual, TupleType) and template.type.is_protocol and
+ self.direction == SUPERTYPE_OF):
+ if mypy.subtypes.is_subtype(actual.fallback, erase_typevars(template)):
+ res.extend(infer_constraints(template, actual.fallback, self.direction))
+ return res
+ return []
else:
return []
+ def infer_constraints_from_protocol_members(self, res: List[Constraint],
+ instance: Instance, template: Instance,
+ subtype: Type, protocol: Instance) -> None:
+ """Infer constraints for situations where either 'template' or 'instance' is a protocol.
+
+ The 'protocol' is the one of two that is an instance of protocol type, 'subtype'
+ is the type used to bind self during inference. Currently, we just infer constrains for
+ every protocol member type (both ways for settable members).
+ """
+ for member in protocol.type.protocol_members:
+ inst = mypy.subtypes.find_member(member, instance, subtype)
+ temp = mypy.subtypes.find_member(member, template, subtype)
+ assert inst is not None and temp is not None
+ # The above is safe since at this point we know that 'instance' is a subtype
+ # of (erased) 'template', therefore it defines all protocol members
+ res.extend(infer_constraints(temp, inst, self.direction))
+ if (mypy.subtypes.IS_SETTABLE in
+ mypy.subtypes.get_member_flags(member, protocol.type)):
+ # Settable members are invariant, add opposite constraints
+ res.extend(infer_constraints(temp, inst, neg_op(self.direction)))
+
def visit_callable_type(self, template: CallableType) -> List[Constraint]:
if isinstance(self.actual, CallableType):
cactual = self.actual
@@ -379,6 +430,14 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
return self.infer_against_overloaded(self.actual, template)
elif isinstance(self.actual, TypeType):
return infer_constraints(template.ret_type, self.actual.item, self.direction)
+ elif isinstance(self.actual, Instance):
+ # Instances with __call__ method defined are considered structural
+ # subtypes of Callable with a compatible signature.
+ call = mypy.subtypes.find_member('__call__', self.actual, self.actual)
+ if call:
+ return infer_constraints(template, call, self.direction)
+ else:
+ return []
else:
return []
diff --git a/mypy/errors.py b/mypy/errors.py
index 1da550a8e68b..eb6a17efaf3c 100644
--- a/mypy/errors.py
+++ b/mypy/errors.py
@@ -11,6 +11,7 @@
T = TypeVar('T')
+allowed_duplicates = ['@overload', 'Got:', 'Expected:']
class ErrorInfo:
@@ -261,7 +262,8 @@ def report(self,
severity: str = 'error',
file: Optional[str] = None,
only_once: bool = False,
- origin_line: Optional[int] = None) -> None:
+ origin_line: Optional[int] = None,
+ offset: int = 0) -> None:
"""Report message at the given line using the current error context.
Args:
@@ -278,6 +280,8 @@ def report(self,
type = None # Omit type context if nested function
if file is None:
file = self.file
+ if offset:
+ message = " " * offset + message
info = ErrorInfo(self.import_context(), file, self.current_module(), type,
self.function_or_member[-1], line, column, severity, message,
blocker, only_once,
@@ -479,6 +483,10 @@ def remove_duplicates(self, errors: List[Tuple[Optional[str], int, int, str, str
while (j >= 0 and errors[j][0] == errors[i][0] and
errors[j][1] == errors[i][1]):
if (errors[j][3] == errors[i][3] and
+ # Allow duplicate notes in overload conficts reporting
+ not (errors[i][3] == 'note' and
+ errors[i][4].strip() in allowed_duplicates
+ or errors[i][4].strip().startswith('def ')) and
errors[j][4] == errors[i][4]): # ignore column
dup = True
break
diff --git a/mypy/fastparse.py b/mypy/fastparse.py
index 12b7e2db69bd..8ba3f193e962 100644
--- a/mypy/fastparse.py
+++ b/mypy/fastparse.py
@@ -514,7 +514,7 @@ def visit_Assign(self, n: ast3.Assign) -> AssignmentStmt:
@with_line
def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt:
if n.value is None: # always allow 'x: int'
- rvalue = TempNode(AnyType(TypeOfAny.special_form)) # type: Expression
+ rvalue = TempNode(AnyType(TypeOfAny.special_form), no_rhs=True) # type: Expression
else:
rvalue = self.visit(n.value)
typ = TypeConverter(self.errors, line=n.lineno).visit(n.annotation)
diff --git a/mypy/join.py b/mypy/join.py
index c771b6df86b9..e306a3f35895 100644
--- a/mypy/join.py
+++ b/mypy/join.py
@@ -9,7 +9,10 @@
PartialType, DeletedType, UninhabitedType, TypeType, true_or_false, TypeOfAny
)
from mypy.maptype import map_instance_to_supertype
-from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype
+from mypy.subtypes import (
+ is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
+ is_protocol_implementation
+)
from mypy import experiments
@@ -137,7 +140,18 @@ def visit_type_var(self, t: TypeVarType) -> Type:
def visit_instance(self, t: Instance) -> Type:
if isinstance(self.s, Instance):
- return join_instances(t, self.s)
+ nominal = join_instances(t, self.s)
+ structural = None # type: Optional[Instance]
+ if t.type.is_protocol and is_protocol_implementation(self.s, t):
+ structural = t
+ elif self.s.type.is_protocol and is_protocol_implementation(t, self.s):
+ structural = self.s
+ # Structural join is preferred in the case where we have found both
+ # structural and nominal and they have same MRO length (see two comments
+ # in join_instances_via_supertype). Otherwise, just return the nominal join.
+ if not structural or is_better(nominal, structural):
+ return nominal
+ return structural
elif isinstance(self.s, FunctionLike):
return join_types(t, self.s.fallback)
elif isinstance(self.s, TypeType):
@@ -237,7 +251,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
return TypedDictType(items, required_keys, fallback)
elif isinstance(self.s, Instance):
- return join_instances(self.s, t.fallback)
+ return join_types(self.s, t.fallback)
else:
return self.default(self.s)
diff --git a/mypy/meet.py b/mypy/meet.py
index 632e0bbc5afa..33d7f6e3df4a 100644
--- a/mypy/meet.py
+++ b/mypy/meet.py
@@ -7,7 +7,7 @@
TupleType, TypedDictType, ErasedType, TypeList, UnionType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeOfAny
)
-from mypy.subtypes import is_equivalent, is_subtype
+from mypy.subtypes import is_equivalent, is_subtype, is_protocol_implementation
from mypy import experiments
@@ -52,7 +52,8 @@ def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool
Note that this effectively checks against erased types, since type
variables are erased at runtime and the overlapping check is based
- on runtime behavior.
+ on runtime behavior. The exception is protocol types, it is not safe,
+ but convenient and is an opt-in behavior.
If use_promotions is True, also consider type promotions (int and
float would only be overlapping if it's True).
@@ -100,7 +101,13 @@ class C(A, B): ...
return True
if s.type._promote and is_overlapping_types(s.type._promote, t):
return True
- return t.type in s.type.mro or s.type in t.type.mro
+ if t.type in s.type.mro or s.type in t.type.mro:
+ return True
+ if t.type.is_protocol and is_protocol_implementation(s, t):
+ return True
+ if s.type.is_protocol and is_protocol_implementation(t, s):
+ return True
+ return False
if isinstance(t, UnionType):
return any(is_overlapping_types(item, s)
for item in t.relevant_items())
diff --git a/mypy/messages.py b/mypy/messages.py
index 022d962f6706..21bf2bf85334 100644
--- a/mypy/messages.py
+++ b/mypy/messages.py
@@ -6,7 +6,7 @@
import re
import difflib
-from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple, Optional
+from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple, Set, Optional, Union
from mypy.erasetype import erase_type
from mypy.errors import Errors
@@ -18,7 +18,7 @@
from mypy.nodes import (
TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
- ReturnStmt, NameExpr, Var
+ ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT
)
@@ -159,12 +159,13 @@ def is_errors(self) -> bool:
return self.errors.is_errors()
def report(self, msg: str, context: Context, severity: str,
- file: Optional[str] = None, origin: Optional[Context] = None) -> None:
+ file: Optional[str] = None, origin: Optional[Context] = None,
+ offset: int = 0) -> None:
"""Report an error or note (unless disabled)."""
if self.disable_count <= 0:
self.errors.report(context.get_line() if context else -1,
context.get_column() if context else -1,
- msg.strip(), severity=severity, file=file,
+ msg.strip(), severity=severity, file=file, offset=offset,
origin_line=origin.get_line() if origin else None)
def fail(self, msg: str, context: Context, file: Optional[str] = None,
@@ -173,9 +174,9 @@ def fail(self, msg: str, context: Context, file: Optional[str] = None,
self.report(msg, context, 'error', file=file, origin=origin)
def note(self, msg: str, context: Context, file: Optional[str] = None,
- origin: Optional[Context] = None) -> None:
+ origin: Optional[Context] = None, offset: int = 0) -> None:
"""Report a note (unless disabled)."""
- self.report(msg, context, 'note', file=file, origin=origin)
+ self.report(msg, context, 'note', file=file, origin=origin, offset=offset)
def warn(self, msg: str, context: Context, file: Optional[str] = None,
origin: Optional[Context] = None) -> None:
@@ -998,6 +999,252 @@ def untyped_decorated_function(self, typ: Type, context: Context) -> None:
def typed_function_untyped_decorator(self, func_name: str, context: Context) -> None:
self.fail('Untyped decorator makes function "{}" untyped'.format(func_name), context)
+ def bad_proto_variance(self, actual: int, tvar_name: str, expected: int,
+ context: Context) -> None:
+ msg = capitalize("{} type variable '{}' used in protocol where"
+ " {} one is expected".format(variance_string(actual),
+ tvar_name,
+ variance_string(expected)))
+ self.fail(msg, context)
+
+ def concrete_only_assign(self, typ: Type, context: Context) -> None:
+ self.fail("Can only assign concrete classes to a variable of type '{}'"
+ .format(self.format(typ)), context)
+
+ def concrete_only_call(self, typ: Type, context: Context) -> None:
+ self.fail("Only concrete class can be given where '{}' is expected"
+ .format(self.format(typ)), context)
+
+ def note_call(self, subtype: Type, call: Type, context: Context) -> None:
+ self.note("'{}.__call__' has type '{}'".format(strip_quotes(self.format(subtype)),
+ self.format(call, verbosity=1)), context)
+
+ def report_protocol_problems(self, subtype: Union[Instance, TupleType, TypedDictType],
+ supertype: Instance, context: Context) -> None:
+ """Report possible protocol conflicts between 'subtype' and 'supertype'.
+
+ This includes missing members, incompatible types, and incompatible
+ attribute flags, such as settable vs read-only or class variable vs
+ instance variable.
+ """
+ from mypy.subtypes import is_subtype, IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC
+ OFFSET = 4 # Four spaces, so that notes will look like this:
+ # note: 'Cls' is missing following 'Proto' members:
+ # note: method, attr
+ MAX_ITEMS = 2 # Maximum number of conflicts, missing members, and overloads shown
+ # List of special situations where we don't want to report additional problems
+ exclusions = {TypedDictType: ['typing.Mapping'],
+ TupleType: ['typing.Iterable', 'typing.Sequence'],
+ Instance: []} # type: Dict[type, List[str]]
+ if supertype.type.fullname() in exclusions[type(subtype)]:
+ return
+ if any(isinstance(tp, UninhabitedType) for tp in supertype.args):
+ # We don't want to add notes for failed inference (e.g. Iterable[]).
+ # This will be only confusing a user even more.
+ return
+
+ if isinstance(subtype, (TupleType, TypedDictType)):
+ if not isinstance(subtype.fallback, Instance):
+ return
+ subtype = subtype.fallback
+
+ # Report missing members
+ missing = get_missing_protocol_members(subtype, supertype)
+ if (missing and len(missing) < len(supertype.type.protocol_members) and
+ len(missing) <= MAX_ITEMS):
+ self.note("'{}' is missing following '{}' protocol member{}:"
+ .format(subtype.type.name(), supertype.type.name(), plural_s(missing)),
+ context)
+ self.note(', '.join(missing), context, offset=OFFSET)
+ elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members):
+ # This is an obviously wrong type: too many missing members
+ return
+
+ # Report member type conflicts
+ conflict_types = get_conflict_protocol_types(subtype, supertype)
+ if conflict_types and (not is_subtype(subtype, erase_type(supertype)) or
+ not subtype.type.defn.type_vars or
+ not supertype.type.defn.type_vars):
+ self.note('Following member(s) of {} have '
+ 'conflicts:'.format(self.format(subtype)), context)
+ for name, got, exp in conflict_types[:MAX_ITEMS]:
+ if (not isinstance(exp, (CallableType, Overloaded)) or
+ not isinstance(got, (CallableType, Overloaded))):
+ self.note('{}: expected {}, got {}'.format(name,
+ *self.format_distinctly(exp, got)),
+ context, offset=OFFSET)
+ else:
+ self.note('Expected:', context, offset=OFFSET)
+ if isinstance(exp, CallableType):
+ self.note(self.pretty_callable(exp), context, offset=2 * OFFSET)
+ else:
+ assert isinstance(exp, Overloaded)
+ self.pretty_overload(exp, context, OFFSET, MAX_ITEMS)
+ self.note('Got:', context, offset=OFFSET)
+ if isinstance(got, CallableType):
+ self.note(self.pretty_callable(got), context, offset=2 * OFFSET)
+ else:
+ assert isinstance(got, Overloaded)
+ self.pretty_overload(got, context, OFFSET, MAX_ITEMS)
+ self.print_more(conflict_types, context, OFFSET, MAX_ITEMS)
+
+ # Report flag conflicts (i.e. settable vs read-only etc.)
+ conflict_flags = get_bad_protocol_flags(subtype, supertype)
+ for name, subflags, superflags in conflict_flags[:MAX_ITEMS]:
+ if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags:
+ self.note('Protocol member {}.{} expected instance variable,'
+ ' got class variable'.format(supertype.type.name(), name), context)
+ if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags:
+ self.note('Protocol member {}.{} expected class variable,'
+ ' got instance variable'.format(supertype.type.name(), name), context)
+ if IS_SETTABLE in superflags and IS_SETTABLE not in subflags:
+ self.note('Protocol member {}.{} expected settable variable,'
+ ' got read-only attribute'.format(supertype.type.name(), name), context)
+ if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags:
+ self.note('Protocol member {}.{} expected class or static method'
+ .format(supertype.type.name(), name), context)
+ self.print_more(conflict_flags, context, OFFSET, MAX_ITEMS)
+
+ def pretty_overload(self, tp: Overloaded, context: Context,
+ offset: int, max_items: int) -> None:
+ for item in tp.items()[:max_items]:
+ self.note('@overload', context, offset=2 * offset)
+ self.note(self.pretty_callable(item), context, offset=2 * offset)
+ if len(tp.items()) > max_items:
+ self.note('<{} more overload(s) not shown>'.format(len(tp.items()) - max_items),
+ context, offset=2 * offset)
+
+ def print_more(self, conflicts: Sequence[Any], context: Context,
+ offset: int, max_items: int) -> None:
+ if len(conflicts) > max_items:
+ self.note('<{} more conflict(s) not shown>'
+ .format(len(conflicts) - max_items),
+ context, offset=offset)
+
+ def pretty_callable(self, tp: CallableType) -> str:
+ """Return a nice easily-readable representation of a callable type.
+ For example:
+ def [T <: int] f(self, x: int, y: T) -> None
+ """
+ s = ''
+ asterisk = False
+ for i in range(len(tp.arg_types)):
+ if s:
+ s += ', '
+ if tp.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not asterisk:
+ s += '*, '
+ asterisk = True
+ if tp.arg_kinds[i] == ARG_STAR:
+ s += '*'
+ asterisk = True
+ if tp.arg_kinds[i] == ARG_STAR2:
+ s += '**'
+ name = tp.arg_names[i]
+ if name:
+ s += name + ': '
+ s += strip_quotes(self.format(tp.arg_types[i]))
+ if tp.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT):
+ s += ' = ...'
+
+ # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list
+ if tp.definition is not None and tp.definition.name() is not None:
+ definition_args = getattr(tp.definition, 'arg_names')
+ if definition_args and tp.arg_names != definition_args \
+ and len(definition_args) > 0:
+ if s:
+ s = ', ' + s
+ s = definition_args[0] + s
+ s = '{}({})'.format(tp.definition.name(), s)
+ else:
+ s = '({})'.format(s)
+
+ s += ' -> ' + strip_quotes(self.format(tp.ret_type))
+ if tp.variables:
+ tvars = []
+ for tvar in tp.variables:
+ if (tvar.upper_bound and isinstance(tvar.upper_bound, Instance) and
+ tvar.upper_bound.type.fullname() != 'builtins.object'):
+ tvars.append('{} <: {}'.format(tvar.name,
+ strip_quotes(self.format(tvar.upper_bound))))
+ elif tvar.values:
+ tvars.append('{} in ({})'
+ .format(tvar.name, ', '.join([strip_quotes(self.format(tp))
+ for tp in tvar.values])))
+ else:
+ tvars.append(tvar.name)
+ s = '[{}] {}'.format(', '.join(tvars), s)
+ return 'def {}'.format(s)
+
+
+def variance_string(variance: int) -> str:
+ if variance == COVARIANT:
+ return 'covariant'
+ elif variance == CONTRAVARIANT:
+ return 'contravariant'
+ else:
+ return 'invariant'
+
+
+def get_missing_protocol_members(left: Instance, right: Instance) -> List[str]:
+ """Find all protocol members of 'right' that are not implemented
+ (i.e. completely missing) in 'left'.
+ """
+ from mypy.subtypes import find_member
+ assert right.type.is_protocol
+ missing = [] # type: List[str]
+ for member in right.type.protocol_members:
+ if not find_member(member, left, left):
+ missing.append(member)
+ return missing
+
+
+def get_conflict_protocol_types(left: Instance, right: Instance) -> List[Tuple[str, Type, Type]]:
+ """Find members that are defined in 'left' but have incompatible types.
+ Return them as a list of ('member', 'got', 'expected').
+ """
+ from mypy.subtypes import find_member, is_subtype, get_member_flags, IS_SETTABLE
+ assert right.type.is_protocol
+ conflicts = [] # type: List[Tuple[str, Type, Type]]
+ for member in right.type.protocol_members:
+ if member in ('__init__', '__new__'):
+ continue
+ supertype = find_member(member, right, left)
+ assert supertype is not None
+ subtype = find_member(member, left, left)
+ if not subtype:
+ continue
+ is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
+ if IS_SETTABLE in get_member_flags(member, right.type):
+ is_compat = is_compat and is_subtype(supertype, subtype)
+ if not is_compat:
+ conflicts.append((member, subtype, supertype))
+ return conflicts
+
+
+def get_bad_protocol_flags(left: Instance, right: Instance
+ ) -> List[Tuple[str, Set[int], Set[int]]]:
+ """Return all incompatible attribute flags for members that are present in both
+ 'left' and 'right'.
+ """
+ from mypy.subtypes import (find_member, get_member_flags,
+ IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC)
+ assert right.type.is_protocol
+ all_flags = [] # type: List[Tuple[str, Set[int], Set[int]]]
+ for member in right.type.protocol_members:
+ if find_member(member, left, left):
+ item = (member,
+ get_member_flags(member, left.type),
+ get_member_flags(member, right.type))
+ all_flags.append(item)
+ bad_flags = []
+ for name, subflags, superflags in all_flags:
+ if (IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags or
+ IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags or
+ IS_SETTABLE in superflags and IS_SETTABLE not in subflags or
+ IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags):
+ bad_flags.append((name, subflags, superflags))
+ return bad_flags
+
def capitalize(s: str) -> str:
"""Capitalize the first character of a string."""
diff --git a/mypy/nodes.py b/mypy/nodes.py
index 16d57c418502..400339c083c1 100644
--- a/mypy/nodes.py
+++ b/mypy/nodes.py
@@ -669,6 +669,7 @@ class Var(SymbolNode):
is_property = False
is_settable_property = False
is_classvar = False
+ is_abstract_var = False
# Set to true when this variable refers to a module we were unable to
# parse for some reason (eg a silenced module)
is_suppressed_import = False
@@ -676,7 +677,7 @@ class Var(SymbolNode):
FLAGS = [
'is_self', 'is_ready', 'is_initialized_in_class', 'is_staticmethod',
'is_classmethod', 'is_property', 'is_settable_property', 'is_suppressed_import',
- 'is_classvar'
+ 'is_classvar', 'is_abstract_var'
]
def __init__(self, name: str, type: 'Optional[mypy.types.Type]' = None) -> None:
@@ -1932,9 +1933,13 @@ class TempNode(Expression):
"""
type = None # type: mypy.types.Type
+ # Is this TempNode used to indicate absence of a right hand side in an annotated assignment?
+ # (e.g. for 'x: int' the rvalue is TempNode(AnyType(TypeOfAny.special_form), no_rhs=True))
+ no_rhs = False # type: bool
- def __init__(self, typ: 'mypy.types.Type') -> None:
+ def __init__(self, typ: 'mypy.types.Type', no_rhs: bool = False) -> None:
self.type = typ
+ self.no_rhs = no_rhs
def __repr__(self) -> str:
return 'TempNode(%s)' % str(self.type)
@@ -1972,7 +1977,53 @@ class is generic then it will be a type constructor of higher kind.
subtypes = None # type: Set[TypeInfo] # Direct subclasses encountered so far
names = None # type: SymbolTable # Names defined directly in this type
is_abstract = False # Does the class have any abstract attributes?
+ is_protocol = False # Is this a protocol class?
+ runtime_protocol = False # Does this protocol support isinstance checks?
abstract_attributes = None # type: List[str]
+ # Protocol members are names of all attributes/methods defined in a protocol
+ # and in all its supertypes (except for 'object').
+ protocol_members = None # type: List[str]
+
+ # The attributes 'assuming' and 'assuming_proper' represent structural subtype matrices.
+ #
+ # In languages with structural subtyping, one can keep a global subtype matrix like this:
+ # . A B C .
+ # A 1 0 0
+ # B 1 1 1
+ # C 1 0 1
+ # .
+ # where 1 indicates that the type in corresponding row is a subtype of the type
+ # in corresponding column. This matrix typically starts filled with all 1's and
+ # a typechecker tries to "disprove" every subtyping relation using atomic (or nominal) types.
+ # However, we don't want to keep this huge global state. Instead, we keep the subtype
+ # information in the form of list of pairs (subtype, supertype) shared by all 'Instance's
+ # with given supertype's TypeInfo. When we enter a subtype check we push a pair in this list
+ # thus assuming that we started with 1 in corresponding matrix element. Such algorithm allows
+ # to treat recursive and mutually recursive protocols and other kinds of complex situations.
+ #
+ # If concurrent/parallel type checking will be added in future,
+ # then there should be one matrix per thread/process to avoid false negatives
+ # during the type checking phase.
+ assuming = None # type: List[Tuple[mypy.types.Instance, mypy.types.Instance]]
+ assuming_proper = None # type: List[Tuple[mypy.types.Instance, mypy.types.Instance]]
+ # Ditto for temporary 'inferring' stack of recursive constraint inference.
+ # It contains Instance's of protocol types that appeared as an argument to
+ # constraints.infer_constraints(). We need 'inferring' to avoid infinite recursion for
+ # recursive and mutually recursive protocols.
+ #
+ # We make 'assuming' and 'inferring' attributes here instead of passing they as kwargs,
+ # since this would require to pass them in many dozens of calls. In particular,
+ # there is a dependency infer_constraint -> is_subtype -> is_callable_subtype ->
+ # -> infer_constraints.
+ inferring = None # type: List[mypy.types.Instance]
+ # '_cache' and '_cache_proper' are subtype caches, implemented as sets of pairs
+ # of (subtype, supertype), where supertypes are instances of given TypeInfo.
+ # We need the caches, since subtype checks for structural types are very slow.
+ _cache = None # type: Set[Tuple[mypy.types.Type, mypy.types.Type]]
+ _cache_proper = None # type: Set[Tuple[mypy.types.Type, mypy.types.Type]]
+ # 'inferring' and 'assuming' can't be also made sets, since we need to use
+ # is_same_type to correctly treat unions.
+
# Classes inheriting from Enum shadow their true members with a __getattr__, so we
# have to treat them as a special case.
is_enum = False
@@ -2016,7 +2067,7 @@ class is generic then it will be a type constructor of higher kind.
FLAGS = [
'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple',
- 'is_newtype'
+ 'is_newtype', 'is_protocol', 'runtime_protocol'
]
def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None:
@@ -2032,6 +2083,11 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No
self._fullname = defn.fullname
self.is_abstract = False
self.abstract_attributes = []
+ self.assuming = []
+ self.assuming_proper = []
+ self.inferring = []
+ self._cache = set()
+ self._cache_proper = set()
self.add_type_vars()
def add_type_vars(self) -> None:
@@ -2051,6 +2107,9 @@ def is_generic(self) -> bool:
return len(self.type_vars) > 0
def get(self, name: str) -> 'Optional[SymbolTableNode]':
+ if self.mro is None: # Might be because of a previous error.
+ return None
+
for cls in self.mro:
n = cls.names.get(name)
if n:
@@ -2063,6 +2122,21 @@ def get_containing_type_info(self, name: str) -> 'Optional[TypeInfo]':
return cls
return None
+ def record_subtype_cache_entry(self, left: 'mypy.types.Instance',
+ right: 'mypy.types.Instance',
+ proper_subtype: bool = False) -> None:
+ if proper_subtype:
+ self._cache_proper.add((left, right))
+ else:
+ self._cache.add((left, right))
+
+ def is_cached_subtype_check(self, left: 'mypy.types.Instance',
+ right: 'mypy.types.Instance',
+ proper_subtype: bool = False) -> bool:
+ if not proper_subtype:
+ return (left, right) in self._cache
+ return (left, right) in self._cache_proper
+
def __getitem__(self, name: str) -> 'SymbolTableNode':
n = self.get(name)
if n:
@@ -2203,6 +2277,7 @@ def serialize(self) -> JsonDict:
'names': self.names.serialize(self.fullname()),
'defn': self.defn.serialize(),
'abstract_attributes': self.abstract_attributes,
+ 'protocol_members': self.protocol_members,
'type_vars': self.type_vars,
'bases': [b.serialize() for b in self.bases],
'_promote': None if self._promote is None else self._promote.serialize(),
@@ -2226,6 +2301,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeInfo':
ti._fullname = data['fullname']
# TODO: Is there a reason to reconstruct ti.subtypes?
ti.abstract_attributes = data['abstract_attributes']
+ ti.protocol_members = data['protocol_members']
ti.type_vars = data['type_vars']
ti.bases = [mypy.types.Instance.deserialize(b) for b in data['bases']]
ti._promote = (None if data['_promote'] is None
diff --git a/mypy/sametypes.py b/mypy/sametypes.py
index 0531ecc2d474..cba80e1ef825 100644
--- a/mypy/sametypes.py
+++ b/mypy/sametypes.py
@@ -68,8 +68,8 @@ def visit_erased_type(self, left: ErasedType) -> bool:
# We can get here when isinstance is used inside a lambda
# whose type is being inferred. In any event, we have no reason
# to think that an ErasedType will end up being the same as
- # any other type, even another ErasedType.
- return False
+ # any other type, except another ErasedType (for protocols).
+ return isinstance(self.right, ErasedType)
def visit_deleted_type(self, left: DeletedType) -> bool:
return isinstance(self.right, DeletedType)
diff --git a/mypy/semanal.py b/mypy/semanal.py
index 6f1ee503b19c..e413c7fd5882 100644
--- a/mypy/semanal.py
+++ b/mypy/semanal.py
@@ -543,11 +543,17 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
if defn.impl is not None:
assert defn.impl is defn.items[-1]
defn.items = defn.items[:-1]
-
elif not self.is_stub_file and not non_overload_indexes:
- self.fail(
- "An overloaded function outside a stub file must have an implementation",
- defn)
+ if not (self.is_class_scope() and self.type.is_protocol):
+ self.fail(
+ "An overloaded function outside a stub file must have an implementation",
+ defn)
+ else:
+ for item in defn.items:
+ if isinstance(item, Decorator):
+ item.func.is_abstract = True
+ else:
+ item.is_abstract = True
if types:
defn.type = Overloaded(types)
@@ -671,6 +677,7 @@ def visit_class_def(self, defn: ClassDef) -> None:
@contextmanager
def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]:
with self.tvar_scope_frame(self.tvar_scope.class_frame()):
+ is_protocol = self.detect_protocol_base(defn)
self.clean_up_bases_and_infer_type_variables(defn)
self.analyze_class_keywords(defn)
if self.analyze_typeddict_classdef(defn):
@@ -710,13 +717,12 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]:
self.setup_class_def_analysis(defn)
self.analyze_base_classes(defn)
self.analyze_metaclass(defn)
-
+ defn.info.is_protocol = is_protocol
+ defn.info.runtime_protocol = False
for decorator in defn.decorators:
self.analyze_class_decorator(defn, decorator)
-
self.enter_class(defn.info)
yield True
-
self.calculate_abstract_status(defn.info)
self.setup_type_promotion(defn)
@@ -743,6 +749,12 @@ def leave_class(self) -> None:
def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None:
decorator.accept(self)
+ if (isinstance(decorator, RefExpr) and
+ decorator.fullname in ('typing.runtime', 'typing_extensions.runtime')):
+ if defn.info.is_protocol:
+ defn.info.runtime_protocol = True
+ else:
+ self.fail('@runtime can only be used with protocol classes', defn)
def calculate_abstract_status(self, typ: TypeInfo) -> None:
"""Calculate abstract status of a class.
@@ -768,6 +780,10 @@ def calculate_abstract_status(self, typ: TypeInfo) -> None:
if fdef.is_abstract and name not in concrete:
typ.is_abstract = True
abstract.append(name)
+ elif isinstance(node, Var):
+ if node.is_abstract_var and name not in concrete:
+ typ.is_abstract = True
+ abstract.append(name)
concrete.add(name)
typ.abstract_attributes = sorted(abstract)
@@ -790,6 +806,21 @@ def setup_type_promotion(self, defn: ClassDef) -> None:
promote_target = self.named_type_or_none(promotions[defn.fullname])
defn.info._promote = promote_target
+ def detect_protocol_base(self, defn: ClassDef) -> bool:
+ for base_expr in defn.base_type_exprs:
+ try:
+ base = expr_to_unanalyzed_type(base_expr)
+ except TypeTranslationError:
+ continue # This will be reported later
+ if not isinstance(base, UnboundType):
+ continue
+ sym = self.lookup_qualified(base.name, base)
+ if sym is None or sym.node is None:
+ continue
+ if sym.node.fullname() in ('typing.Protocol', 'typing_extensions.Protocol'):
+ return True
+ return False
+
def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None:
"""Remove extra base classes such as Generic and infer type vars.
@@ -817,17 +848,26 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None:
tvars = self.analyze_typevar_declaration(base)
if tvars is not None:
if declared_tvars:
- self.fail('Duplicate Generic in bases', defn)
+ self.fail('Only single Generic[...] or Protocol[...] can be in bases', defn)
removed.append(i)
declared_tvars.extend(tvars)
+ if isinstance(base, UnboundType):
+ sym = self.lookup_qualified(base.name, base)
+ if sym is not None and sym.node is not None:
+ if (sym.node.fullname() in ('typing.Protocol',
+ 'typing_extensions.Protocol') and
+ i not in removed):
+ # also remove bare 'Protocol' bases
+ removed.append(i)
all_tvars = self.get_all_bases_tvars(defn, removed)
if declared_tvars:
if len(remove_dups(declared_tvars)) < len(declared_tvars):
- self.fail("Duplicate type variables in Generic[...]", defn)
+ self.fail("Duplicate type variables in Generic[...] or Protocol[...]", defn)
declared_tvars = remove_dups(declared_tvars)
if not set(all_tvars).issubset(set(declared_tvars)):
- self.fail("If Generic[...] is present it should list all type variables", defn)
+ self.fail("If Generic[...] or Protocol[...] is present"
+ " it should list all type variables", defn)
# In case of error, Generic tvars will go first
declared_tvars = remove_dups(declared_tvars + all_tvars)
else:
@@ -849,7 +889,9 @@ def analyze_typevar_declaration(self, t: Type) -> Optional[TypeVarList]:
sym = self.lookup_qualified(unbound.name, unbound)
if sym is None or sym.node is None:
return None
- if sym.node.fullname() == 'typing.Generic':
+ if (sym.node.fullname() == 'typing.Generic' or
+ sym.node.fullname() == 'typing.Protocol' and t.args or
+ sym.node.fullname() == 'typing_extensions.Protocol' and t.args):
tvars = [] # type: TypeVarList
for arg in unbound.args:
tvar = self.analyze_unbound_tvar(arg)
@@ -1594,10 +1636,16 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
if s.type:
allow_tuple_literal = isinstance(s.lvalues[-1], (TupleExpr, ListExpr))
s.type = self.anal_type(s.type, allow_tuple_literal=allow_tuple_literal)
+ if (self.type and self.type.is_protocol and isinstance(lval, NameExpr) and
+ isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs):
+ if isinstance(lval.node, Var):
+ lval.node.is_abstract_var = True
else:
- # Set the type if the rvalue is a simple literal.
- if (s.type is None and len(s.lvalues) == 1 and
- isinstance(s.lvalues[0], NameExpr)):
+ if (any(isinstance(lv, NameExpr) and lv.is_def for lv in s.lvalues) and
+ self.type and self.type.is_protocol and not self.is_func_scope()):
+ self.fail('All protocol members must have explicitly declared types', s)
+ # Set the type if the rvalue is a simple literal (even if the above error occurred).
+ if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr):
if s.lvalues[0].is_def:
s.type = self.analyze_simple_literal_type(s.rvalue)
if s.type:
@@ -1832,18 +1880,22 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr],
def analyze_member_lvalue(self, lval: MemberExpr) -> None:
lval.accept(self)
- if (self.is_self_member_ref(lval) and
- self.type.get(lval.name) is None):
- # Implicit attribute definition in __init__.
- lval.is_def = True
- v = Var(lval.name)
- v.set_line(lval)
- v._fullname = self.qualified_name(lval.name)
- v.info = self.type
- v.is_ready = False
- lval.def_var = v
- lval.node = v
- self.type.names[lval.name] = SymbolTableNode(MDEF, v, implicit=True)
+ if self.is_self_member_ref(lval):
+ node = self.type.get(lval.name)
+ if node is None or isinstance(node.node, Var) and node.node.is_abstract_var:
+ if self.type.is_protocol and node is None:
+ self.fail("Protocol members cannot be defined via assignment to self", lval)
+ else:
+ # Implicit attribute definition in __init__.
+ lval.is_def = True
+ v = Var(lval.name)
+ v.set_line(lval)
+ v._fullname = self.qualified_name(lval.name)
+ v.info = self.type
+ v.is_ready = False
+ lval.def_var = v
+ lval.node = v
+ self.type.names[lval.name] = SymbolTableNode(MDEF, v, implicit=True)
self.check_lvalue_validity(lval.node, lval)
def is_self_member_ref(self, memberexpr: MemberExpr) -> bool:
@@ -1906,6 +1958,8 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> None:
newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback)
newtype_class_info.tuple_type = old_type
elif isinstance(old_type, Instance):
+ if old_type.type.is_protocol:
+ self.fail("NewType cannot be used with protocol classes", s)
newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type)
else:
message = "Argument 2 to NewType(...) must be subclassable (got {})"
@@ -3778,6 +3832,11 @@ def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) -
('False', bool_type),
('__debug__', bool_type),
])
+ else:
+ # We are running tests without 'bool' in builtins.
+ # TODO: Find a permanent solution to this problem.
+ # Maybe add 'bool' to all fixtures?
+ literal_types.append(('True', AnyType(TypeOfAny.special_form)))
for name, typ in literal_types:
v = Var(name, typ)
@@ -4018,12 +4077,18 @@ def visit_class_def(self, tdef: ClassDef) -> None:
if not tdef.info.is_named_tuple:
for type in tdef.info.bases:
self.analyze(type)
+ if tdef.info.is_protocol:
+ if not isinstance(type, Instance) or not type.type.is_protocol:
+ if type.type.fullname() != 'builtins.object':
+ self.fail('All bases of a protocol must be protocols', tdef)
# Recompute MRO now that we have analyzed all modules, to pick
# up superclasses of bases imported from other modules in an
# import loop. (Only do so if we succeeded the first time.)
if tdef.info.mro:
tdef.info.mro = [] # Force recomputation
calculate_class_mro(tdef, self.fail_blocker)
+ if tdef.info.is_protocol:
+ add_protocol_members(tdef.info)
if tdef.analyzed is not None:
if isinstance(tdef.analyzed, TypedDictExpr):
self.analyze(tdef.analyzed.info.typeddict_type)
@@ -4141,6 +4206,16 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance:
return Instance(node, [any_type] * len(node.defn.type_vars))
+def add_protocol_members(typ: TypeInfo) -> None:
+ members = set() # type: Set[str]
+ if typ.mro:
+ for base in typ.mro[:-1]: # we skip "object" since everyone implements it
+ if base.is_protocol:
+ for name in base.names:
+ members.add(name)
+ typ.protocol_members = sorted(list(members))
+
+
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
return sig.copy_modified(arg_types=[new] + sig.arg_types[1:])
diff --git a/mypy/subtypes.py b/mypy/subtypes.py
index 97701598921c..5dfbc7c405bc 100644
--- a/mypy/subtypes.py
+++ b/mypy/subtypes.py
@@ -1,9 +1,11 @@
-from typing import List, Optional, Dict, Callable, cast
+from typing import List, Optional, Dict, Callable, Tuple, Iterator, Set, Union, cast
+from contextlib import contextmanager
from mypy.types import (
- Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneTyp, Instance, TypeVarType,
- CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, TypeList,
- PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, TypeOfAny
+ Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneTyp, function_type,
+ Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded,
+ ErasedType, TypeList, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance,
+ FunctionLike, TypeOfAny
)
import mypy.applytype
import mypy.constraints
@@ -12,15 +14,22 @@
# import mypy.solve
from mypy import messages, sametypes
from mypy.nodes import (
- CONTRAVARIANT, COVARIANT,
- ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
+ FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT,
+ ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2
)
from mypy.maptype import map_instance_to_supertype
+from mypy.expandtype import expand_type_by_instance
from mypy.sametypes import is_same_type
from mypy import experiments
+# Flags for detected protocol members
+IS_SETTABLE = 1
+IS_CLASSVAR = 2
+IS_CLASS_OR_STATIC = 3
+
+
TypeParameterChecker = Callable[[Type, Type, int], bool]
@@ -35,7 +44,8 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool:
def is_subtype(left: Type, right: Type,
type_parameter_checker: TypeParameterChecker = check_type_parameter,
- *, ignore_pos_arg_names: bool = False) -> bool:
+ *, ignore_pos_arg_names: bool = False,
+ ignore_declared_variance: bool = False) -> bool:
"""Is 'left' subtype of 'right'?
Also consider Any to be a subtype of any type, and vice versa. This
@@ -69,7 +79,8 @@ def is_subtype(left: Type, right: Type,
return True
# otherwise, fall through
return left.accept(SubtypeVisitor(right, type_parameter_checker,
- ignore_pos_arg_names=ignore_pos_arg_names))
+ ignore_pos_arg_names=ignore_pos_arg_names,
+ ignore_declared_variance=ignore_declared_variance))
def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool:
@@ -93,10 +104,12 @@ class SubtypeVisitor(TypeVisitor[bool]):
def __init__(self, right: Type,
type_parameter_checker: TypeParameterChecker,
- *, ignore_pos_arg_names: bool = False) -> None:
+ *, ignore_pos_arg_names: bool = False,
+ ignore_declared_variance: bool = False) -> None:
self.right = right
self.check_type_parameter = type_parameter_checker
self.ignore_pos_arg_names = ignore_pos_arg_names
+ self.ignore_declared_variance = ignore_declared_variance
# visit_x(left) means: is left (which is an instance of X) a subtype of
# right?
@@ -130,24 +143,34 @@ def visit_instance(self, left: Instance) -> bool:
if isinstance(right, TupleType) and right.fallback.type.is_enum:
return is_subtype(left, right.fallback)
if isinstance(right, Instance):
- # NOTO: left.type.mro may be None in quick mode if there
+ if right.type.is_cached_subtype_check(left, right):
+ return True
+ # NOTE: left.type.mro may be None in quick mode if there
# was an error somewhere.
if left.type.mro is not None:
for base in left.type.mro:
+ # TODO: Also pass recursively ignore_declared_variance
if base._promote and is_subtype(
base._promote, self.right, self.check_type_parameter,
ignore_pos_arg_names=self.ignore_pos_arg_names):
+ right.type.record_subtype_cache_entry(left, right)
return True
rname = right.type.fullname()
- if not left.type.has_base(rname) and rname != 'builtins.object':
- return False
-
- # Map left type to corresponding right instances.
- t = map_instance_to_supertype(left, right.type)
-
- return all(self.check_type_parameter(lefta, righta, tvar.variance)
- for lefta, righta, tvar in
- zip(t.args, right.args, right.type.defn.type_vars))
+ # Always try a nominal check if possible,
+ # there might be errors that a user wants to silence *once*.
+ if ((left.type.has_base(rname) or rname == 'builtins.object') and
+ not self.ignore_declared_variance):
+ # Map left type to corresponding right instances.
+ t = map_instance_to_supertype(left, right.type)
+ nominal = all(self.check_type_parameter(lefta, righta, tvar.variance)
+ for lefta, righta, tvar in
+ zip(t.args, right.args, right.type.defn.type_vars))
+ if nominal:
+ right.type.record_subtype_cache_entry(left, right)
+ return nominal
+ if right.type.is_protocol and is_protocol_implementation(left, right):
+ return True
+ return False
if isinstance(right, TypeType):
item = right.item
if isinstance(item, TupleType):
@@ -163,7 +186,14 @@ def visit_instance(self, left: Instance) -> bool:
and is_named_instance(item, 'enum.Enum')):
return True
return is_named_instance(item, 'builtins.object')
- return False
+ if isinstance(right, CallableType):
+ # Special case: Instance can be a subtype of Callable.
+ call = find_member('__call__', left, left)
+ if call:
+ return is_subtype(call, right)
+ return False
+ else:
+ return False
def visit_type_var(self, left: TypeVarType) -> bool:
right = self.right
@@ -302,6 +332,190 @@ def visit_type_type(self, left: TypeType) -> bool:
return False
+@contextmanager
+def pop_on_exit(stack: List[Tuple[Instance, Instance]],
+ left: Instance, right: Instance) -> Iterator[None]:
+ stack.append((left, right))
+ yield
+ stack.pop()
+
+
+def is_protocol_implementation(left: Instance, right: Instance,
+ proper_subtype: bool = False) -> bool:
+ """Check whether 'left' implements the protocol 'right'.
+
+ If 'proper_subtype' is True, then check for a proper subtype.
+ Treat recursive protocols by using the 'assuming' structural subtype matrix
+ (in sparse representation, i.e. as a list of pairs (subtype, supertype)),
+ see also comment in nodes.TypeInfo. When we enter a check for classes
+ (A, P), defined as following::
+
+ class P(Protocol):
+ def f(self) -> P: ...
+ class A:
+ def f(self) -> A: ...
+
+ this results in A being a subtype of P without infinite recursion.
+ On every false result, we pop the assumption, thus avoiding an infinite recursion
+ as well.
+ """
+ assert right.type.is_protocol
+ assuming = right.type.assuming_proper if proper_subtype else right.type.assuming
+ for (l, r) in reversed(assuming):
+ if sametypes.is_same_type(l, left) and sametypes.is_same_type(r, right):
+ return True
+ with pop_on_exit(assuming, left, right):
+ for member in right.type.protocol_members:
+ # nominal subtyping currently ignores '__init__' and '__new__' signatures
+ if member in ('__init__', '__new__'):
+ continue
+ # The third argument below indicates to what self type is bound.
+ # We always bind self to the subtype. (Similarly to nominal types).
+ supertype = find_member(member, right, left)
+ assert supertype is not None
+ subtype = find_member(member, left, left)
+ # Useful for debugging:
+ # print(member, 'of', left, 'has type', subtype)
+ # print(member, 'of', right, 'has type', supertype)
+ if not subtype:
+ return False
+ if not proper_subtype:
+ # Nominal check currently ignores arg names
+ is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
+ else:
+ is_compat = is_proper_subtype(subtype, supertype)
+ if not is_compat:
+ return False
+ if isinstance(subtype, NoneTyp) and member.startswith('__') and member.endswith('__'):
+ # We want __hash__ = None idiom to work even without --strict-optional
+ return False
+ subflags = get_member_flags(member, left.type)
+ superflags = get_member_flags(member, right.type)
+ if IS_SETTABLE in superflags:
+ # Check opposite direction for settable attributes.
+ if not is_subtype(supertype, subtype):
+ return False
+ if (IS_CLASSVAR in subflags) != (IS_CLASSVAR in superflags):
+ return False
+ if IS_SETTABLE in superflags and IS_SETTABLE not in subflags:
+ return False
+ # This rule is copied from nominal check in checker.py
+ if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags:
+ return False
+ right.type.record_subtype_cache_entry(left, right, proper_subtype)
+ return True
+
+
+def find_member(name: str, itype: Instance, subtype: Type) -> Optional[Type]:
+ """Find the type of member by 'name' in 'itype's TypeInfo.
+
+ Fin the member type after applying type arguments from 'itype', and binding
+ 'self' to 'subtype'. Return None if member was not found.
+ """
+ # TODO: this code shares some logic with checkmember.analyze_member_access,
+ # consider refactoring.
+ info = itype.type
+ method = info.get_method(name)
+ if method:
+ if method.is_property:
+ assert isinstance(method, OverloadedFuncDef)
+ dec = method.items[0]
+ assert isinstance(dec, Decorator)
+ return find_node_type(dec.var, itype, subtype)
+ return find_node_type(method, itype, subtype)
+ else:
+ # don't have such method, maybe variable or decorator?
+ node = info.get(name)
+ if not node:
+ v = None
+ else:
+ v = node.node
+ if isinstance(v, Decorator):
+ v = v.var
+ if isinstance(v, Var):
+ return find_node_type(v, itype, subtype)
+ if not v and name not in ['__getattr__', '__setattr__', '__getattribute__']:
+ for method_name in ('__getattribute__', '__getattr__'):
+ # Normally, mypy assumes that instances that define __getattr__ have all
+ # attributes with the corresponding return type. If this will produce
+ # many false negatives, then this could be prohibited for
+ # structural subtyping.
+ method = info.get_method(method_name)
+ if method and method.info.fullname() != 'builtins.object':
+ getattr_type = find_node_type(method, itype, subtype)
+ if isinstance(getattr_type, CallableType):
+ return getattr_type.ret_type
+ if itype.type.fallback_to_any:
+ return AnyType(TypeOfAny.special_form)
+ return None
+
+
+def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
+ """Detect whether a member 'name' is settable, whether it is an
+ instance or class variable, and whether it is class or static method.
+
+ The flags are defined as following:
+ * IS_SETTABLE: whether this attribute can be set, not set for methods and
+ non-settable properties;
+ * IS_CLASSVAR: set if the variable is annotated as 'x: ClassVar[t]';
+ * IS_CLASS_OR_STATIC: set for methods decorated with @classmethod or
+ with @staticmethod.
+ """
+ method = info.get_method(name)
+ setattr_meth = info.get_method('__setattr__')
+ if method:
+ # this could be settable property
+ if method.is_property:
+ assert isinstance(method, OverloadedFuncDef)
+ dec = method.items[0]
+ assert isinstance(dec, Decorator)
+ if dec.var.is_settable_property or setattr_meth:
+ return {IS_SETTABLE}
+ return set()
+ node = info.get(name)
+ if not node:
+ if setattr_meth:
+ return {IS_SETTABLE}
+ return set()
+ v = node.node
+ if isinstance(v, Decorator):
+ if v.var.is_staticmethod or v.var.is_classmethod:
+ return {IS_CLASS_OR_STATIC}
+ # just a variable
+ if isinstance(v, Var):
+ flags = {IS_SETTABLE}
+ if v.is_classvar:
+ flags.add(IS_CLASSVAR)
+ return flags
+ return set()
+
+
+def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) -> Type:
+ """Find type of a variable or method 'node' (maybe also a decorated method).
+ Apply type arguments from 'itype', and bind 'self' to 'subtype'.
+ """
+ from mypy.checkmember import bind_self
+ if isinstance(node, FuncBase):
+ typ = function_type(node,
+ fallback=Instance(itype.type.mro[-1], [])) # type: Optional[Type]
+ else:
+ typ = node.type
+ if typ is None:
+ return AnyType(TypeOfAny.from_error)
+ # We don't need to bind 'self' for static methods, since there is no 'self'.
+ if isinstance(node, FuncBase) or isinstance(typ, FunctionLike) and not node.is_staticmethod:
+ assert isinstance(typ, FunctionLike)
+ signature = bind_self(typ, subtype)
+ if node.is_property:
+ assert isinstance(signature, CallableType)
+ typ = signature.ret_type
+ else:
+ typ = signature
+ itype = map_instance_to_supertype(itype, node.info)
+ typ = expand_type_by_instance(typ, itype)
+ return typ
+
+
def is_callable_subtype(left: CallableType, right: CallableType,
ignore_return: bool = False,
ignore_pos_arg_names: bool = False,
@@ -548,8 +762,11 @@ def restrict_subtype_away(t: Type, s: Type) -> Type:
if isinstance(t, UnionType):
# Since runtime type checks will ignore type arguments, erase the types.
erased_s = erase_type(s)
+ # TODO: Implement more robust support for runtime isinstance() checks,
+ # see issue #3827
new_items = [item for item in t.relevant_items()
- if (not is_proper_subtype(erase_type(item), erased_s)
+ if (not (is_proper_subtype(erase_type(item), erased_s) or
+ is_proper_subtype(item, erased_s))
or isinstance(item, AnyType))]
return UnionType.make_union(new_items)
else:
@@ -601,26 +818,38 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
def visit_instance(self, left: Instance) -> bool:
right = self.right
if isinstance(right, Instance):
+ if right.type.is_cached_subtype_check(left, right, proper_subtype=True):
+ return True
for base in left.type.mro:
if base._promote and is_proper_subtype(base._promote, right):
+ right.type.record_subtype_cache_entry(left, right, proper_subtype=True)
return True
- if not left.type.has_base(right.type.fullname()):
- return False
-
- def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool:
- if variance == COVARIANT:
- return is_proper_subtype(leftarg, rightarg)
- elif variance == CONTRAVARIANT:
- return is_proper_subtype(rightarg, leftarg)
- else:
- return sametypes.is_same_type(leftarg, rightarg)
-
- # Map left type to corresponding right instances.
- left = map_instance_to_supertype(left, right.type)
-
- return all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in
- zip(left.args, right.args, right.type.defn.type_vars))
+ if left.type.has_base(right.type.fullname()):
+ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool:
+ if variance == COVARIANT:
+ return is_proper_subtype(leftarg, rightarg)
+ elif variance == CONTRAVARIANT:
+ return is_proper_subtype(rightarg, leftarg)
+ else:
+ return sametypes.is_same_type(leftarg, rightarg)
+ # Map left type to corresponding right instances.
+ left = map_instance_to_supertype(left, right.type)
+
+ nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in
+ zip(left.args, right.args, right.type.defn.type_vars))
+ if nominal:
+ right.type.record_subtype_cache_entry(left, right, proper_subtype=True)
+ return nominal
+ if (right.type.is_protocol and
+ is_protocol_implementation(left, right, proper_subtype=True)):
+ return True
+ return False
+ if isinstance(right, CallableType):
+ call = find_member('__call__', left, left)
+ if call:
+ return is_proper_subtype(call, right)
+ return False
return False
def visit_type_var(self, left: TypeVarType) -> bool:
diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py
index 4ac59bd94f57..6685ed1011af 100644
--- a/mypy/test/testcheck.py
+++ b/mypy/test/testcheck.py
@@ -72,6 +72,7 @@
'check-generic-subtyping.test',
'check-varargs.test',
'check-newsyntax.test',
+ 'check-protocols.test',
'check-underscores.test',
'check-classvar.test',
'check-enum.test',
diff --git a/mypy/types.py b/mypy/types.py
index a94b6e2b7152..192bf08eb35d 100644
--- a/mypy/types.py
+++ b/mypy/types.py
@@ -188,6 +188,15 @@ def __init__(self,
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_unbound_type(self)
+ def __hash__(self) -> int:
+ return hash((self.name, self.optional, tuple(self.args)))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, UnboundType):
+ return NotImplemented
+ return (self.name == other.name and self.optional == other.optional and
+ self.args == other.args)
+
def serialize(self) -> JsonDict:
return {'.class': 'UnboundType',
'name': self.name,
@@ -316,6 +325,12 @@ def copy_modified(self,
return AnyType(type_of_any=type_of_any, source_any=original_any,
line=self.line, column=self.column)
+ def __hash__(self) -> int:
+ return hash(AnyType)
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, AnyType)
+
def serialize(self) -> JsonDict:
return {'.class': 'AnyType'}
@@ -350,6 +365,12 @@ def __init__(self, is_noreturn: bool = False, line: int = -1, column: int = -1)
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_uninhabited_type(self)
+ def __hash__(self) -> int:
+ return hash(UninhabitedType)
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, UninhabitedType)
+
def serialize(self) -> JsonDict:
return {'.class': 'UninhabitedType',
'is_noreturn': self.is_noreturn}
@@ -371,6 +392,12 @@ class NoneTyp(Type):
def __init__(self, line: int = -1, column: int = -1) -> None:
super().__init__(line, column)
+ def __hash__(self) -> int:
+ return hash(NoneTyp)
+
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, NoneTyp)
+
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_none_type(self)
@@ -450,6 +477,14 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T:
type_ref = None # type: str
+ def __hash__(self) -> int:
+ return hash((self.type, tuple(self.args)))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Instance):
+ return NotImplemented
+ return self.type == other.type and self.args == other.args
+
def serialize(self) -> Union[JsonDict, str]:
assert self.type is not None
type_ref = self.type.fullname()
@@ -512,6 +547,14 @@ def erase_to_union_or_bound(self) -> Type:
else:
return self.upper_bound
+ def __hash__(self) -> int:
+ return hash(self.id)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TypeVarType):
+ return NotImplemented
+ return self.id == other.id
+
def serialize(self) -> JsonDict:
assert not self.id.is_meta_var()
return {'.class': 'TypeVarType',
@@ -790,6 +833,23 @@ def type_var_ids(self) -> List[TypeVarId]:
a.append(tv.id)
return a
+ def __hash__(self) -> int:
+ return hash((self.ret_type, self.is_type_obj(),
+ self.is_ellipsis_args, self.name,
+ tuple(self.arg_types), tuple(self.arg_names), tuple(self.arg_kinds)))
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, CallableType):
+ return (self.ret_type == other.ret_type and
+ self.arg_types == other.arg_types and
+ self.arg_names == other.arg_names and
+ self.arg_kinds == other.arg_kinds and
+ self.name == other.name and
+ self.is_type_obj() == other.is_type_obj() and
+ self.is_ellipsis_args == other.is_ellipsis_args)
+ else:
+ return NotImplemented
+
def serialize(self) -> JsonDict:
# TODO: As an optimization, leave out everything related to
# generic functions for non-generic functions.
@@ -872,6 +932,14 @@ def get_name(self) -> Optional[str]:
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_overloaded(self)
+ def __hash__(self) -> int:
+ return hash(tuple(self.items()))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Overloaded):
+ return NotImplemented
+ return self.items() == other.items()
+
def serialize(self) -> JsonDict:
return {'.class': 'Overloaded',
'items': [t.serialize() for t in self.items()],
@@ -913,6 +981,14 @@ def length(self) -> int:
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_tuple_type(self)
+ def __hash__(self) -> int:
+ return hash((tuple(self.items), self.fallback))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TupleType):
+ return NotImplemented
+ return self.items == other.items and self.fallback == other.fallback
+
def serialize(self) -> JsonDict:
return {'.class': 'TupleType',
'items': [t.serialize() for t in self.items],
@@ -965,6 +1041,21 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str],
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_typeddict_type(self)
+ def __hash__(self) -> int:
+ return hash((frozenset(self.items.items()), self.fallback,
+ frozenset(self.required_keys)))
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, TypedDictType):
+ if frozenset(self.items.keys()) != frozenset(other.items.keys()):
+ return False
+ for (_, left_item_type, right_item_type) in self.zip(other):
+ if not left_item_type == right_item_type:
+ return False
+ return self.fallback == other.fallback and self.required_keys == other.required_keys
+ else:
+ return NotImplemented
+
def serialize(self) -> JsonDict:
return {'.class': 'TypedDictType',
'items': [[n, t.serialize()] for (n, t) in self.items.items()],
@@ -1062,6 +1153,14 @@ def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None:
self.can_be_false = any(item.can_be_false for item in items)
super().__init__(line, column)
+ def __hash__(self) -> int:
+ return hash(frozenset(self.items))
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, UnionType):
+ return NotImplemented
+ return frozenset(self.items) == frozenset(other.items)
+
@staticmethod
def make_union(items: List[Type], line: int = -1, column: int = -1) -> Type:
if len(items) > 1:
@@ -1258,6 +1357,14 @@ def make_normalized(item: Type, *, line: int = -1, column: int = -1) -> Type:
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_type_type(self)
+ def __hash__(self) -> int:
+ return hash(self.item)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, TypeType):
+ return NotImplemented
+ return self.item == other.item
+
def serialize(self) -> JsonDict:
return {'.class': 'TypeType', 'item': self.item.serialize()}
diff --git a/test-data/unit/check-abstract.test b/test-data/unit/check-abstract.test
index 9a0d4afc5971..cd8cbe43b595 100644
--- a/test-data/unit/check-abstract.test
+++ b/test-data/unit/check-abstract.test
@@ -174,8 +174,8 @@ def f(cls: Type[A]) -> A:
def g() -> A:
return A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'm'
-f(A) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected
-f(B) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected
+f(A) # E: Only concrete class can be given where 'Type[A]' is expected
+f(B) # E: Only concrete class can be given where 'Type[A]' is expected
f(C) # OK
x: Type[B]
f(x) # OK
@@ -200,7 +200,7 @@ Alias = A
GoodAlias = C
Alias() # E: Cannot instantiate abstract class 'A' with abstract attribute 'm'
GoodAlias()
-f(Alias) # E: Only non-abstract class can be given where 'Type[__main__.A]' is expected
+f(Alias) # E: Only concrete class can be given where 'Type[A]' is expected
f(GoodAlias)
[out]
@@ -218,14 +218,14 @@ class C(B):
var: Type[A]
var()
-var = A # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]'
-var = B # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]'
+var = A # E: Can only assign concrete classes to a variable of type 'Type[A]'
+var = B # E: Can only assign concrete classes to a variable of type 'Type[A]'
var = C # OK
var_old = None # type: Type[A] # Old syntax for variable annotations
var_old()
-var_old = A # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]'
-var_old = B # E: Can only assign non-abstract classes to a variable of type 'Type[__main__.A]'
+var_old = A # E: Can only assign concrete classes to a variable of type 'Type[A]'
+var_old = B # E: Can only assign concrete classes to a variable of type 'Type[A]'
var_old = C # OK
[out]
diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test
index 7d7f0fa2ddd6..87a25a8e7858 100644
--- a/test-data/unit/check-classes.test
+++ b/test-data/unit/check-classes.test
@@ -3152,19 +3152,20 @@ def f(TB: Type[B]):
[case testMetaclassIterable]
from typing import Iterable, Iterator
-class BadMeta(type):
+class ImplicitMeta(type):
def __iter__(self) -> Iterator[int]: yield 1
-class Bad(metaclass=BadMeta): pass
+class Implicit(metaclass=ImplicitMeta): pass
-for _ in Bad: pass # E: Iterable expected
+for _ in Implicit: pass
+reveal_type(list(Implicit)) # E: Revealed type is 'builtins.list[builtins.int*]'
-class GoodMeta(type, Iterable[int]):
+class ExplicitMeta(type, Iterable[int]):
def __iter__(self) -> Iterator[int]: yield 1
-class Good(metaclass=GoodMeta): pass
-for _ in Good: pass
-reveal_type(list(Good)) # E: Revealed type is 'builtins.list[builtins.int*]'
+class Explicit(metaclass=ExplicitMeta): pass
+for _ in Explicit: pass
+reveal_type(list(Explicit)) # E: Revealed type is 'builtins.list[builtins.int*]'
[builtins fixtures/list.pyi]
diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test
index 0d962533eabd..4ad8620d53d4 100644
--- a/test-data/unit/check-generics.test
+++ b/test-data/unit/check-generics.test
@@ -1201,7 +1201,7 @@ reveal_type(D[str, int]().c()) # E: Revealed type is 'builtins.str*'
from typing import TypeVar, Generic
T = TypeVar('T')
-class A(Generic[T, T]): # E: Duplicate type variables in Generic[...]
+class A(Generic[T, T]): # E: Duplicate type variables in Generic[...] or Protocol[...]
pass
a = A[int]()
@@ -1218,7 +1218,7 @@ class A(Generic[T]):
class B(Generic[T]):
pass
-class C(A[T], B[S], Generic[T]): # E: If Generic[...] is present it should list all type variables
+class C(A[T], B[S], Generic[T]): # E: If Generic[...] or Protocol[...] is present it should list all type variables
pass
c = C[int, str]()
diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test
index 557d42a1ec42..916a17752c7d 100644
--- a/test-data/unit/check-incremental.test
+++ b/test-data/unit/check-incremental.test
@@ -1501,6 +1501,99 @@ class MyClass:
[rechecked]
[stale]
+[case testIncrementalWorksWithBasicProtocols]
+import a
+[file a.py]
+from b import P
+
+x: int
+y: P[int]
+x = y.meth()
+
+class C:
+ def meth(self) -> int:
+ pass
+y = C()
+
+[file a.py.2]
+from b import P
+
+x: str
+y: P[str]
+x = y.meth()
+
+class C:
+ def meth(self) -> str:
+ pass
+y = C()
+[file b.py]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T', covariant=True)
+class P(Protocol[T]):
+ def meth(self) -> T:
+ pass
+
+[case testIncrementalSwitchFromNominalToStructural]
+import a
+[file a.py]
+from b import B, fun
+class C(B):
+ def x(self) -> int: pass
+ def y(self) -> int: pass
+fun(C())
+
+[file b.py]
+from typing import Protocol
+class B:
+ def x(self) -> float: pass
+def fun(arg: B) -> None:
+ arg.x()
+
+[file b.py.2]
+from typing import Protocol
+class B(Protocol):
+ def x(self) -> float: pass
+def fun(arg: B) -> None:
+ arg.x()
+
+[file a.py.3]
+from b import fun
+class C:
+ def x(self) -> int: pass
+ def y(self) -> int: pass
+fun(C())
+[out1]
+[out2]
+[out3]
+
+[case testIncrementalSwitchFromStructuralToNominal]
+import a
+[file a.py]
+from b import fun
+class C:
+ def x(self) -> int: pass
+ def y(self) -> int: pass
+fun(C())
+
+[file b.py]
+from typing import Protocol
+class B(Protocol):
+ def x(self) -> float: pass
+def fun(arg: B) -> None:
+ arg.x()
+
+[file b.py.2]
+from typing import Protocol
+class B:
+ def x(self) -> float: pass
+def fun(arg: B) -> None:
+ arg.x()
+
+[out1]
+[out2]
+tmp/a.py:5: error: Argument 1 to "fun" has incompatible type "C"; expected "B"
+
[case testIncrementalWorksWithNamedTuple]
import foo
diff --git a/test-data/unit/check-lists.test b/test-data/unit/check-lists.test
index c9c67e80d4fb..c575f4b76da4 100644
--- a/test-data/unit/check-lists.test
+++ b/test-data/unit/check-lists.test
@@ -64,7 +64,7 @@ class C: pass
[case testListWithStarExpr]
(x, *a) = [1, 2, 3]
a = [1, *[2, 3]]
-reveal_type(a) # E: Revealed type is 'builtins.list[builtins.int]'
+reveal_type(a) # E: Revealed type is 'builtins.list[builtins.int*]'
b = [0, *a]
reveal_type(b) # E: Revealed type is 'builtins.list[builtins.int*]'
c = [*a, 0]
diff --git a/test-data/unit/check-newtype.test b/test-data/unit/check-newtype.test
index 4c7e1468a167..c9e01edf3ef4 100644
--- a/test-data/unit/check-newtype.test
+++ b/test-data/unit/check-newtype.test
@@ -332,6 +332,22 @@ B = NewType('B', A)
class C(B): pass # E: Cannot subclass NewType
[out]
+[case testCannotUseNewTypeWithProtocols]
+from typing import Protocol, NewType
+
+class P(Protocol):
+ attr: int
+class D:
+ attr: int
+
+C = NewType('C', P) # E: NewType cannot be used with protocol classes
+
+x: C = C(D()) # We still accept this, treating 'C' as non-protocol subclass.
+reveal_type(x.attr) # E: Revealed type is 'builtins.int'
+x.bad_attr # E: "C" has no attribute "bad_attr"
+C(1) # E: Argument 1 to "C" has incompatible type "int"; expected "P"
+[out]
+
[case testNewTypeAny]
from typing import NewType
Any = NewType('Any', int)
diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test
new file mode 100644
index 000000000000..ba190988be61
--- /dev/null
+++ b/test-data/unit/check-protocols.test
@@ -0,0 +1,2120 @@
+-- Simple protocol types
+-- ---------------------
+
+[case testCannotInstantiateProtocol]
+from typing import Protocol
+
+class P(Protocol):
+ def meth(self) -> None:
+ pass
+
+P() # E: Cannot instantiate protocol class "P"
+
+[case testSimpleProtocolOneMethod]
+from typing import Protocol
+
+class P(Protocol):
+ def meth(self) -> None:
+ pass
+
+class B: pass
+class C:
+ def meth(self) -> None:
+ pass
+
+x: P
+def fun(x: P) -> None:
+ x.meth()
+ x.meth(x) # E: Too many arguments for "meth" of "P"
+ x.bad # E: "P" has no attribute "bad"
+
+x = C()
+x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P")
+
+fun(C())
+fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "P"
+
+def fun2() -> P:
+ return C()
+def fun3() -> P:
+ return B() # E: Incompatible return value type (got "B", expected "P")
+
+[case testSimpleProtocolOneAbstractMethod]
+from typing import Protocol
+from abc import abstractmethod
+
+class P(Protocol):
+ @abstractmethod
+ def meth(self) -> None:
+ pass
+
+class B: pass
+class C:
+ def meth(self) -> None:
+ pass
+class D(B):
+ def meth(self) -> None:
+ pass
+
+x: P
+def fun(x: P) -> None:
+ x.meth()
+ x.meth(x) # E: Too many arguments for "meth" of "P"
+ x.bad # E: "P" has no attribute "bad"
+
+x = C()
+x = D()
+x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P")
+fun(C())
+fun(D())
+fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "P"
+fun(x)
+
+[case testProtocolMethodBodies]
+from typing import Protocol, List
+
+class P(Protocol):
+ def meth(self) -> int:
+ return 'no way' # E: Incompatible return value type (got "str", expected "int")
+
+# explicit ellipsis is OK in protocol methods
+class P2(Protocol):
+ def meth2(self) -> List[int]:
+ ...
+[builtins fixtures/list.pyi]
+
+[case testSimpleProtocolOneMethodOverride]
+from typing import Protocol, Union
+
+class P(Protocol):
+ def meth(self) -> Union[int, str]:
+ pass
+class SubP(P, Protocol):
+ def meth(self) -> int:
+ pass
+
+class B: pass
+class C:
+ def meth(self) -> int:
+ pass
+z: P
+x: SubP
+def fun(x: SubP) -> str:
+ x.bad # E: "SubP" has no attribute "bad"
+ return x.meth() # E: Incompatible return value type (got "int", expected "str")
+
+z = x
+x = C()
+x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "SubP")
+
+reveal_type(fun(C())) # E: Revealed type is 'builtins.str'
+fun(B()) # E: Argument 1 to "fun" has incompatible type "B"; expected "SubP"
+
+[case testSimpleProtocolTwoMethodsMerge]
+from typing import Protocol
+
+class P1(Protocol):
+ def meth1(self) -> int:
+ pass
+class P2(Protocol):
+ def meth2(self) -> str:
+ pass
+class P(P1, P2, Protocol): pass
+
+class B: pass
+class C1:
+ def meth1(self) -> int:
+ pass
+class C2(C1):
+ def meth2(self) -> str:
+ pass
+class C:
+ def meth1(self) -> int:
+ pass
+ def meth2(self) -> str:
+ pass
+
+class AnotherP(Protocol):
+ def meth1(self) -> int:
+ pass
+ def meth2(self) -> str:
+ pass
+
+x: P
+reveal_type(x.meth1()) # E: Revealed type is 'builtins.int'
+reveal_type(x.meth2()) # E: Revealed type is 'builtins.str'
+
+c: C
+c1: C1
+c2: C2
+y: AnotherP
+
+x = c
+x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P")
+x = c1 # E: Incompatible types in assignment (expression has type "C1", variable has type "P") \
+ # N: 'C1' is missing following 'P' protocol member: \
+ # N: meth2
+x = c2
+x = y
+y = x
+
+[case testSimpleProtocolTwoMethodsExtend]
+from typing import Protocol
+
+class P1(Protocol):
+ def meth1(self) -> int:
+ pass
+class P2(P1, Protocol):
+ def meth2(self) -> str:
+ pass
+
+class Cbad:
+ def meth1(self) -> int:
+ pass
+
+class C:
+ def meth1(self) -> int:
+ pass
+ def meth2(self) -> str:
+ pass
+
+x: P2
+reveal_type(x.meth1()) # E: Revealed type is 'builtins.int'
+reveal_type(x.meth2()) # E: Revealed type is 'builtins.str'
+
+x = C() # OK
+x = Cbad() # E: Incompatible types in assignment (expression has type "Cbad", variable has type "P2") \
+ # N: 'Cbad' is missing following 'P2' protocol member: \
+ # N: meth2
+
+[case testProtocolMethodVsAttributeErrors]
+from typing import Protocol
+
+class P(Protocol):
+ def meth(self) -> int:
+ pass
+class C:
+ meth: int
+x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: meth: expected Callable[[], int], got "int"
+
+[case testProtocolMethodVsAttributeErrors2]
+from typing import Protocol
+
+class P(Protocol):
+ @property
+ def meth(self) -> int:
+ pass
+class C:
+ def meth(self) -> int:
+ pass
+x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: meth: expected "int", got Callable[[], int]
+[builtins fixtures/property.pyi]
+
+[case testCannotAssignNormalToProtocol]
+from typing import Protocol
+
+class P(Protocol):
+ def meth(self) -> int:
+ pass
+class C:
+ def meth(self) -> int:
+ pass
+
+x: C
+y: P
+x = y # E: Incompatible types in assignment (expression has type "P", variable has type "C")
+
+[case testIndependentProtocolSubtyping]
+from typing import Protocol
+
+class P1(Protocol):
+ def meth(self) -> int:
+ pass
+class P2(Protocol):
+ def meth(self) -> int:
+ pass
+
+x1: P1
+x2: P2
+
+x1 = x2
+x2 = x1
+
+def f1(x: P1) -> None: pass
+def f2(x: P2) -> None: pass
+
+f1(x2)
+f2(x1)
+
+[case testNoneDisablesProtocolImplementation]
+from typing import Protocol
+
+class MyHashable(Protocol):
+ def __my_hash__(self) -> int:
+ return 0
+
+class C:
+ __my_hash__ = None
+
+var: MyHashable = C() # E: Incompatible types in assignment (expression has type "C", variable has type "MyHashable")
+
+[case testNoneDisablesProtocolSubclassingWithStrictOptional]
+# flags: --strict-optional
+from typing import Protocol
+
+class MyHashable(Protocol):
+ def __my_hash__(self) -> int:
+ return 0
+
+class C(MyHashable):
+ __my_hash__ = None # E: Incompatible types in assignment \
+(expression has type None, base class "MyHashable" defined the type as Callable[[MyHashable], int])
+
+[case testProtocolsWithNoneAndStrictOptional]
+# flags: --strict-optional
+from typing import Protocol
+class P(Protocol):
+ x = 0 # type: int
+
+class C:
+ x = None
+
+x: P = C() # Error!
+def f(x: P) -> None: pass
+f(C()) # Error!
+[out]
+main:9: error: Incompatible types in assignment (expression has type "C", variable has type "P")
+main:9: note: Following member(s) of "C" have conflicts:
+main:9: note: x: expected "int", got None
+main:11: error: Argument 1 to "f" has incompatible type "C"; expected "P"
+main:11: note: Following member(s) of "C" have conflicts:
+main:11: note: x: expected "int", got None
+
+-- Semanal errors in protocol types
+-- --------------------------------
+
+[case testBasicSemanalErrorsInProtocols]
+from typing import Protocol, Generic, TypeVar, Iterable
+
+T = TypeVar('T', covariant=True)
+S = TypeVar('S', covariant=True)
+
+class P1(Protocol[T, T]): # E: Duplicate type variables in Generic[...] or Protocol[...]
+ def meth(self) -> T:
+ pass
+
+class P2(Protocol[T], Protocol[S]): # E: Only single Generic[...] or Protocol[...] can be in bases
+ def meth(self) -> T:
+ pass
+
+class P3(Protocol[T], Generic[S]): # E: Only single Generic[...] or Protocol[...] can be in bases
+ def meth(self) -> T:
+ pass
+
+class P4(Protocol[T]):
+ attr: Iterable[S] # E: Invalid type "__main__.S"
+
+class P5(Iterable[S], Protocol[T]): # E: If Generic[...] or Protocol[...] is present it should list all type variables
+ def meth(self) -> T:
+ pass
+
+[case testProhibitSelfDefinitionInProtocols]
+from typing import Protocol
+
+class P(Protocol):
+ def __init__(self, a: int) -> None:
+ self.a = a # E: Protocol members cannot be defined via assignment to self \
+ # E: "P" has no attribute "a"
+
+class B: pass
+class C:
+ def __init__(self, a: int) -> None:
+ pass
+
+x: P
+x = B()
+# The above has an incompatible __init__, but mypy ignores this for nominal subtypes?
+x = C(1)
+
+class P2(Protocol):
+ a: int
+ def __init__(self) -> None:
+ self.a = 1
+
+class B2(P2):
+ a: int
+
+x2: P2 = B2() # OK
+
+[case testProtocolAndRuntimeAreDefinedAlsoInTypingExtensions]
+from typing_extensions import Protocol, runtime
+
+@runtime
+class P(Protocol):
+ def meth(self) -> int:
+ pass
+
+x: object
+if isinstance(x, P):
+ reveal_type(x) # E: Revealed type is '__main__.P'
+ reveal_type(x.meth()) # E: Revealed type is 'builtins.int'
+
+class C:
+ def meth(self) -> int:
+ pass
+
+z: P = C()
+[builtins fixtures/dict.pyi]
+
+[case testProtocolsCannotInheritFromNormal]
+from typing import Protocol
+
+class C: pass
+class D: pass
+
+class P(C, Protocol): # E: All bases of a protocol must be protocols
+ attr: int
+
+class P2(P, D, Protocol): # E: All bases of a protocol must be protocols
+ pass
+
+P2() # E: Cannot instantiate abstract class 'P2' with abstract attribute 'attr'
+p: P2
+reveal_type(p.attr) # E: Revealed type is 'builtins.int'
+
+-- Generic protocol types
+-- ----------------------
+
+[case testGenericMethodWithProtocol]
+from typing import Protocol, TypeVar
+T = TypeVar('T')
+
+class P(Protocol):
+ def meth(self, x: int) -> int:
+ return x
+class C:
+ def meth(self, x: T) -> T:
+ return x
+
+x: P = C()
+
+[case testGenericMethodWithProtocol2]
+from typing import Protocol, TypeVar
+T = TypeVar('T')
+
+class P(Protocol):
+ def meth(self, x: T) -> T:
+ return x
+class C:
+ def meth(self, x: int) -> int:
+ return x
+
+x: P = C()
+[out]
+main:11: error: Incompatible types in assignment (expression has type "C", variable has type "P")
+main:11: note: Following member(s) of "C" have conflicts:
+main:11: note: Expected:
+main:11: note: def [T] meth(self, x: T) -> T
+main:11: note: Got:
+main:11: note: def meth(self, x: int) -> int
+
+[case testAutomaticProtocolVariance]
+from typing import TypeVar, Protocol
+
+T = TypeVar('T')
+
+# In case of these errors we proceed with declared variance.
+class Pco(Protocol[T]): # E: Invariant type variable 'T' used in protocol where covariant one is expected
+ def meth(self) -> T:
+ pass
+class Pcontra(Protocol[T]): # E: Invariant type variable 'T' used in protocol where contravariant one is expected
+ def meth(self, x: T) -> None:
+ pass
+class Pinv(Protocol[T]):
+ attr: T
+
+class A: pass
+class B(A): pass
+
+x1: Pco[B]
+y1: Pco[A]
+x1 = y1 # E: Incompatible types in assignment (expression has type Pco[A], variable has type Pco[B])
+y1 = x1 # E: Incompatible types in assignment (expression has type Pco[B], variable has type Pco[A])
+
+x2: Pcontra[B]
+y2: Pcontra[A]
+y2 = x2 # E: Incompatible types in assignment (expression has type Pcontra[B], variable has type Pcontra[A])
+x2 = y2 # E: Incompatible types in assignment (expression has type Pcontra[A], variable has type Pcontra[B])
+
+x3: Pinv[B]
+y3: Pinv[A]
+y3 = x3 # E: Incompatible types in assignment (expression has type Pinv[B], variable has type Pinv[A])
+x3 = y3 # E: Incompatible types in assignment (expression has type Pinv[A], variable has type Pinv[B])
+
+[case testProtocolVarianceWithCallableAndList]
+from typing import Protocol, TypeVar, Callable, List
+T = TypeVar('T')
+S = TypeVar('S')
+T_co = TypeVar('T_co', covariant=True)
+
+class P(Protocol[T, S]): # E: Invariant type variable 'T' used in protocol where covariant one is expected \
+ # E: Invariant type variable 'S' used in protocol where contravariant one is expected
+ def fun(self, callback: Callable[[T], S]) -> None: pass
+
+class P2(Protocol[T_co]): # E: Covariant type variable 'T_co' used in protocol where invariant one is expected
+ lst: List[T_co]
+[builtins fixtures/list.pyi]
+
+[case testProtocolVarianceWithUnusedVariable]
+from typing import Protocol, TypeVar
+T = TypeVar('T')
+
+class P(Protocol[T]): # E: Invariant type variable 'T' used in protocol where covariant one is expected
+ attr: int
+
+[case testGenericProtocolsInference1]
+from typing import Protocol, Sequence, TypeVar
+
+T = TypeVar('T', covariant=True)
+
+class Closeable(Protocol[T]):
+ def close(self) -> T:
+ pass
+
+class F:
+ def close(self) -> int:
+ return 0
+
+def close(arg: Closeable[T]) -> T:
+ return arg.close()
+
+def close_all(args: Sequence[Closeable[T]]) -> T:
+ for arg in args:
+ arg.close()
+ return args[0].close()
+
+arg: Closeable[int]
+
+reveal_type(close(F())) # E: Revealed type is 'builtins.int*'
+reveal_type(close(arg)) # E: Revealed type is 'builtins.int*'
+reveal_type(close_all([F()])) # E: Revealed type is 'builtins.int*'
+reveal_type(close_all([arg])) # E: Revealed type is 'builtins.int*'
+[builtins fixtures/isinstancelist.pyi]
+
+[case testProtocolGenericInference2]
+from typing import Generic, TypeVar, Protocol
+T = TypeVar('T')
+S = TypeVar('S')
+
+class P(Protocol[T, S]):
+ x: T
+ y: S
+
+class C:
+ x: int
+ y: int
+
+def fun3(x: P[T, T]) -> T:
+ pass
+reveal_type(fun3(C())) # E: Revealed type is 'builtins.int*'
+
+[case testProtocolGenericInferenceCovariant]
+from typing import Generic, TypeVar, Protocol
+T = TypeVar('T', covariant=True)
+S = TypeVar('S', covariant=True)
+U = TypeVar('U')
+
+class P(Protocol[T, S]):
+ def x(self) -> T: pass
+ def y(self) -> S: pass
+
+class C:
+ def x(self) -> int: pass
+ def y(self) -> int: pass
+
+def fun4(x: U, y: P[U, U]) -> U:
+ pass
+reveal_type(fun4('a', C())) # E: Revealed type is 'builtins.object*'
+
+[case testUnrealtedGenericProtolsEquivalent]
+from typing import TypeVar, Protocol
+T = TypeVar('T')
+
+class PA(Protocol[T]):
+ attr: int
+ def meth(self) -> T: pass
+ def other(self, arg: T) -> None: pass
+class PB(Protocol[T]): # exactly the same as above
+ attr: int
+ def meth(self) -> T: pass
+ def other(self, arg: T) -> None: pass
+
+def fun(x: PA[T]) -> PA[T]:
+ y: PB[T] = x
+ z: PB[T]
+ return z
+
+x: PA
+y: PB
+x = y
+y = x
+
+xi: PA[int]
+yi: PB[int]
+xi = yi
+yi = xi
+
+[case testGenericSubProtocols]
+from typing import TypeVar, Protocol, Tuple, Generic
+
+T = TypeVar('T')
+S = TypeVar('S')
+
+class P1(Protocol[T]):
+ attr1: T
+class P2(P1[T], Protocol[T, S]):
+ attr2: Tuple[T, S]
+
+class C:
+ def __init__(self, a1: int, a2: Tuple[int, int]) -> None:
+ self.attr1 = a1
+ self.attr2 = a2
+
+c: C
+var: P2[int, int] = c
+var2: P2[int, str] = c # E: Incompatible types in assignment (expression has type "C", variable has type P2[int, str]) \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: attr2: expected "Tuple[int, str]", got "Tuple[int, int]"
+
+class D(Generic[T]):
+ attr1: T
+class E(D[T]):
+ attr2: Tuple[T, T]
+
+def f(x: T) -> T:
+ z: P2[T, T] = E[T]()
+ y: P2[T, T] = D[T]() # E: Incompatible types in assignment (expression has type D[T], variable has type P2[T, T]) \
+ # N: 'D' is missing following 'P2' protocol member: \
+ # N: attr2
+ return x
+[builtins fixtures/isinstancelist.pyi]
+
+[case testGenericSubProtocolsExtensionInvariant]
+from typing import TypeVar, Protocol, Union
+
+T = TypeVar('T')
+S = TypeVar('S')
+
+class P1(Protocol[T]):
+ attr1: T
+class P2(Protocol[T]):
+ attr2: T
+class P(P1[T], P2[S], Protocol):
+ pass
+
+class C:
+ attr1: int
+ attr2: str
+
+class A:
+ attr1: A
+class B:
+ attr2: B
+class D(A, B): pass
+
+x: P = D() # Same as P[Any, Any]
+
+var: P[Union[int, P], Union[P, str]] = C() # E: Incompatible types in assignment (expression has type "C", variable has type P[Union[int, P[Any, Any]], Union[P[Any, Any], str]]) \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: attr1: expected "Union[int, P[Any, Any]]", got "int" \
+ # N: attr2: expected "Union[P[Any, Any], str]", got "str"
+
+[case testGenericSubProtocolsExtensionCovariant]
+from typing import TypeVar, Protocol, Union
+
+T = TypeVar('T', covariant=True)
+S = TypeVar('S', covariant=True)
+
+class P1(Protocol[T]):
+ def attr1(self) -> T: pass
+class P2(Protocol[T]):
+ def attr2(self) -> T: pass
+class P(P1[T], P2[S], Protocol):
+ pass
+
+class C:
+ def attr1(self) -> int: pass
+ def attr2(self) -> str: pass
+
+var: P[Union[int, P], Union[P, str]] = C() # OK for covariant
+var2: P[Union[str, P], Union[P, int]] = C()
+[out]
+main:18: error: Incompatible types in assignment (expression has type "C", variable has type P[Union[str, P[Any, Any]], Union[P[Any, Any], int]])
+main:18: note: Following member(s) of "C" have conflicts:
+main:18: note: Expected:
+main:18: note: def attr1(self) -> Union[str, P[Any, Any]]
+main:18: note: Got:
+main:18: note: def attr1(self) -> int
+main:18: note: Expected:
+main:18: note: def attr2(self) -> Union[P[Any, Any], int]
+main:18: note: Got:
+main:18: note: def attr2(self) -> str
+
+[case testSelfTypesWithProtocolsBehaveAsWithNominal]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T', bound=Shape)
+class Shape(Protocol):
+ def combine(self: T, other: T) -> T:
+ pass
+
+class NonProtoShape:
+ def combine(self: T, other: T) -> T:
+ pass
+class Circle:
+ def combine(self: T, other: Shape) -> T:
+ pass
+class Triangle:
+ def combine(self, other: Shape) -> Shape:
+ pass
+class Bad:
+ def combine(self, other: int) -> str:
+ pass
+
+def f(s: Shape) -> None: pass
+s: Shape
+
+f(NonProtoShape())
+f(Circle())
+s = Triangle()
+s = Bad()
+
+n2: NonProtoShape = s
+[out]
+main:26: error: Incompatible types in assignment (expression has type "Triangle", variable has type "Shape")
+main:26: note: Following member(s) of "Triangle" have conflicts:
+main:26: note: Expected:
+main:26: note: def combine(self, other: Triangle) -> Triangle
+main:26: note: Got:
+main:26: note: def combine(self, other: Shape) -> Shape
+main:27: error: Incompatible types in assignment (expression has type "Bad", variable has type "Shape")
+main:27: note: Following member(s) of "Bad" have conflicts:
+main:27: note: Expected:
+main:27: note: def combine(self, other: Bad) -> Bad
+main:27: note: Got:
+main:27: note: def combine(self, other: int) -> str
+main:29: error: Incompatible types in assignment (expression has type "Shape", variable has type "NonProtoShape")
+
+[case testBadVarianceInProtocols]
+from typing import Protocol, TypeVar
+
+T_co = TypeVar('T_co', covariant=True)
+T_contra = TypeVar('T_contra', contravariant=True)
+
+class Proto(Protocol[T_co, T_contra]): # type: ignore
+ def one(self, x: T_co) -> None: # E: Cannot use a covariant type variable as a parameter
+ pass
+ def other(self) -> T_contra: # E: Cannot use a contravariant type variable as return type
+ pass
+
+# Check that we respect user overrides of variance after the errors are reported
+x: Proto[int, float]
+y: Proto[float, int]
+y = x # OK
+[builtins fixtures/list.pyi]
+
+[case testSubtleBadVarianceInProtocols]
+from typing import Protocol, TypeVar, Iterable, Sequence
+
+T_co = TypeVar('T_co', covariant=True)
+T_contra = TypeVar('T_contra', contravariant=True)
+
+class Proto(Protocol[T_co, T_contra]): # E: Covariant type variable 'T_co' used in protocol where contravariant one is expected \
+ # E: Contravariant type variable 'T_contra' used in protocol where covariant one is expected
+ def one(self, x: Iterable[T_co]) -> None:
+ pass
+ def other(self) -> Sequence[T_contra]:
+ pass
+
+# Check that we respect user overrides of variance after the errors are reported
+x: Proto[int, float]
+y: Proto[float, int]
+y = x # OK
+[builtins fixtures/list.pyi]
+
+-- Recursive protocol types
+-- ------------------------
+
+[case testRecursiveProtocols1]
+from typing import Protocol, Sequence, List, Generic, TypeVar
+
+T = TypeVar('T')
+
+class Traversable(Protocol):
+ @property
+ def leaves(self) -> Sequence[Traversable]: pass
+
+class C: pass
+
+class D(Generic[T]):
+ leaves: List[D[T]]
+
+t: Traversable
+t = D[int]() # OK
+t = C() # E: Incompatible types in assignment (expression has type "C", variable has type "Traversable")
+[builtins fixtures/list.pyi]
+
+[case testRecursiveProtocols2]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T')
+class Linked(Protocol[T]):
+ val: T
+ def next(self) -> Linked[T]: pass
+
+class L:
+ val: int
+ def next(self) -> L: pass
+
+def last(seq: Linked[T]) -> T:
+ pass
+
+reveal_type(last(L())) # E: Revealed type is 'builtins.int*'
+[builtins fixtures/list.pyi]
+
+[case testRecursiveProtocolSubtleMismatch]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T')
+class Linked(Protocol[T]):
+ val: T
+ def next(self) -> Linked[T]: pass
+class L:
+ val: int
+ def next(self) -> int: pass
+
+def last(seq: Linked[T]) -> T:
+ pass
+last(L()) # E: Argument 1 to "last" has incompatible type "L"; expected Linked[]
+
+[case testMutuallyRecursiveProtocols]
+from typing import Protocol, Sequence, List
+
+class P1(Protocol):
+ @property
+ def attr1(self) -> Sequence[P2]: pass
+class P2(Protocol):
+ @property
+ def attr2(self) -> Sequence[P1]: pass
+
+class C: pass
+class A:
+ attr1: List[B]
+class B:
+ attr2: List[A]
+
+t: P1
+t = A() # OK
+t = B() # E: Incompatible types in assignment (expression has type "B", variable has type "P1")
+t = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P1")
+[builtins fixtures/list.pyi]
+
+[case testMutuallyRecursiveProtocolsTypesWithSubteMismatch]
+from typing import Protocol, Sequence, List
+
+class P1(Protocol):
+ @property
+ def attr1(self) -> Sequence[P2]: pass
+class P2(Protocol):
+ @property
+ def attr2(self) -> Sequence[P1]: pass
+
+class C: pass
+class A:
+ attr1: List[B]
+class B:
+ attr2: List[C]
+
+t: P1
+t = A() # E: Incompatible types in assignment (expression has type "A", variable has type "P1") \
+ # N: Following member(s) of "A" have conflicts: \
+ # N: attr1: expected Sequence[P2], got List[B]
+[builtins fixtures/list.pyi]
+
+[case testMutuallyRecursiveProtocolsTypesWithSubteMismatchWriteable]
+from typing import Protocol
+
+class P1(Protocol):
+ @property
+ def attr1(self) -> P2: pass
+class P2(Protocol):
+ attr2: P1
+
+class A:
+ attr1: B
+class B:
+ attr2: A
+
+x: P1 = A() # E: Incompatible types in assignment (expression has type "A", variable has type "P1") \
+ # N: Following member(s) of "A" have conflicts: \
+ # N: attr1: expected "P2", got "B"
+[builtins fixtures/property.pyi]
+
+-- FIXME: things like this should work
+[case testWeirdRecursiveInferenceForProtocols-skip]
+from typing import Protocol, TypeVar, Generic
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+
+class P(Protocol[T_co]):
+ def meth(self) -> P[T_co]: pass
+
+class C(Generic[T]):
+ def meth(self) -> C[T]: pass
+
+x: C[int]
+def f(arg: P[T]) -> T: pass
+reveal_type(f(x)) #E: Revealed type is 'builtins.int*'
+
+-- @property, @classmethod and @staticmethod in protocol types
+-- -----------------------------------------------------------
+
+[case testCannotInstantiateAbstractMethodExplicitProtocolSubtypes]
+from typing import Protocol
+from abc import abstractmethod
+
+class P(Protocol):
+ @abstractmethod
+ def meth(self) -> int:
+ pass
+
+class A(P):
+ pass
+
+A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'meth'
+
+class C(A):
+ def meth(self) -> int:
+ pass
+class C2(P):
+ def meth(self) -> int:
+ pass
+
+C()
+C2()
+
+[case testCannotInstantiateAbstractVariableExplicitProtocolSubtypes]
+from typing import Protocol
+
+class P(Protocol):
+ attr: int
+
+class A(P):
+ pass
+
+A() # E: Cannot instantiate abstract class 'A' with abstract attribute 'attr'
+
+class C(A):
+ attr: int
+class C2(P):
+ def __init__(self) -> None:
+ self.attr = 1
+
+C()
+C2()
+
+class P2(Protocol):
+ attr: int = 1
+
+class B(P2): pass
+B() # OK, attr is not abstract
+
+[case testClassVarsInProtocols]
+from typing import Protocol, ClassVar
+
+class PInst(Protocol):
+ v: int
+
+class PClass(Protocol):
+ v: ClassVar[int]
+
+class CInst:
+ v: int
+
+class CClass:
+ v: ClassVar[int]
+
+x: PInst
+y: PClass
+
+x = CInst()
+x = CClass() # E: Incompatible types in assignment (expression has type "CClass", variable has type "PInst") \
+ # N: Protocol member PInst.v expected instance variable, got class variable
+y = CClass()
+y = CInst() # E: Incompatible types in assignment (expression has type "CInst", variable has type "PClass") \
+ # N: Protocol member PClass.v expected class variable, got instance variable
+
+[case testPropertyInProtocols]
+from typing import Protocol
+
+class PP(Protocol):
+ @property
+ def attr(self) -> int:
+ pass
+
+class P(Protocol):
+ attr: int
+
+x: P
+y: PP
+y = x
+
+x2: P
+y2: PP
+x2 = y2 # E: Incompatible types in assignment (expression has type "PP", variable has type "P") \
+ # N: Protocol member P.attr expected settable variable, got read-only attribute
+[builtins fixtures/property.pyi]
+
+[case testSettablePropertyInProtocols]
+from typing import Protocol
+
+class PPS(Protocol):
+ @property
+ def attr(self) -> int:
+ pass
+ @attr.setter
+ def attr(self, x: int) -> None:
+ pass
+
+class PP(Protocol):
+ @property
+ def attr(self) -> int:
+ pass
+
+class P(Protocol):
+ attr: int
+
+x: P
+z: PPS
+z = x
+
+x2: P
+z2: PPS
+x2 = z2
+
+y3: PP
+z3: PPS
+y3 = z3
+
+y4: PP
+z4: PPS
+z4 = y4 # E: Incompatible types in assignment (expression has type "PP", variable has type "PPS") \
+ # N: Protocol member PPS.attr expected settable variable, got read-only attribute
+[builtins fixtures/property.pyi]
+
+[case testStaticAndClassMethodsInProtocols]
+from typing import Protocol, Type, TypeVar
+
+class P(Protocol):
+ def meth(self, x: int) -> str:
+ pass
+
+class PC(Protocol):
+ @classmethod
+ def meth(cls, x: int) -> str:
+ pass
+
+class B:
+ @staticmethod
+ def meth(x: int) -> str:
+ pass
+
+class C:
+ def meth(self, x: int) -> str:
+ pass
+
+x: P
+x = C()
+x = B()
+
+y: PC
+y = B()
+y = C() # E: Incompatible types in assignment (expression has type "C", variable has type "PC") \
+ # N: Protocol member PC.meth expected class or static method
+[builtins fixtures/classmethod.pyi]
+
+[case testOverloadedMethodsInProtocols]
+from typing import overload, Protocol, Union
+
+class P(Protocol):
+ @overload
+ def f(self, x: int) -> int: pass
+ @overload
+ def f(self, x: str) -> str: pass
+
+class C:
+ def f(self, x: Union[int, str]) -> None:
+ pass
+class D:
+ def f(self, x: int) -> None:
+ pass
+
+x: P = C()
+x = D()
+[out]
+main:17: error: Incompatible types in assignment (expression has type "D", variable has type "P")
+main:17: note: Following member(s) of "D" have conflicts:
+main:17: note: Expected:
+main:17: note: @overload
+main:17: note: def f(self, x: int) -> int
+main:17: note: @overload
+main:17: note: def f(self, x: str) -> str
+main:17: note: Got:
+main:17: note: def f(self, x: int) -> None
+
+[case testCannotInstantiateProtocolWithOverloadedUnimplementedMethod]
+from typing import overload, Protocol
+
+class P(Protocol):
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: str) -> bytes: pass
+class C(P):
+ pass
+C() # E: Cannot instantiate abstract class 'C' with abstract attribute 'meth'
+
+[case testCanUseOverloadedImplementationsInProtocols]
+from typing import overload, Protocol, Union
+class P(Protocol):
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: str) -> bool: pass
+ def meth(self, x: Union[int, str]):
+ if isinstance(x, int):
+ return x
+ return True
+
+class C(P):
+ pass
+x = C()
+reveal_type(x.meth('hi')) # E: Revealed type is 'builtins.bool'
+[builtins fixtures/isinstance.pyi]
+
+[case testProtocolsWithIdenticalOverloads]
+from typing import overload, Protocol
+
+class PA(Protocol):
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: str) -> bytes: pass
+class PB(Protocol): # identical to above
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: str) -> bytes: pass
+
+x: PA
+y: PB
+x = y
+def fun(arg: PB) -> None: pass
+fun(x)
+
+[case testProtocolsWithIncompatibleOverloads]
+from typing import overload, Protocol
+
+class PA(Protocol):
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: str) -> bytes: pass
+class PB(Protocol):
+ @overload
+ def meth(self, x: int) -> int: pass
+ @overload
+ def meth(self, x: bytes) -> str: pass
+
+x: PA
+y: PB
+x = y
+[out]
+main:16: error: Incompatible types in assignment (expression has type "PB", variable has type "PA")
+main:16: note: Following member(s) of "PB" have conflicts:
+main:16: note: Expected:
+main:16: note: @overload
+main:16: note: def meth(self, x: int) -> int
+main:16: note: @overload
+main:16: note: def meth(self, x: str) -> bytes
+main:16: note: Got:
+main:16: note: @overload
+main:16: note: def meth(self, x: int) -> int
+main:16: note: @overload
+main:16: note: def meth(self, x: bytes) -> str
+
+-- Join and meet with protocol types
+-- ---------------------------------
+
+[case testJoinProtocolWithProtocol]
+from typing import Protocol
+
+class P(Protocol):
+ attr: int
+class P2(Protocol):
+ attr: int
+ attr2: str
+
+x: P
+y: P2
+
+l0 = [x, x]
+l1 = [y, y]
+l = [x, y]
+reveal_type(l0) # E: Revealed type is 'builtins.list[__main__.P*]'
+reveal_type(l1) # E: Revealed type is 'builtins.list[__main__.P2*]'
+reveal_type(l) # E: Revealed type is 'builtins.list[__main__.P*]'
+[builtins fixtures/list.pyi]
+
+[case testJoinOfIncompatibleProtocols]
+from typing import Protocol
+
+class P(Protocol):
+ attr: int
+class P2(Protocol):
+ attr2: str
+
+x: P
+y: P2
+reveal_type([x, y]) # E: Revealed type is 'builtins.list[builtins.object*]'
+[builtins fixtures/list.pyi]
+
+[case testJoinProtocolWithNormal]
+from typing import Protocol
+
+class P(Protocol):
+ attr: int
+
+class C:
+ attr: int
+
+x: P
+y: C
+
+l = [x, y]
+
+reveal_type(l) # E: Revealed type is 'builtins.list[__main__.P*]'
+[builtins fixtures/list.pyi]
+
+[case testMeetProtocolWithProtocol]
+from typing import Protocol, Callable, TypeVar
+
+class P(Protocol):
+ attr: int
+class P2(Protocol):
+ attr: int
+ attr2: str
+
+T = TypeVar('T')
+def f(x: Callable[[T, T], None]) -> T: pass
+def g(x: P, y: P2) -> None: pass
+reveal_type(f(g)) # E: Revealed type is '__main__.P2*'
+
+[case testMeetOfIncompatibleProtocols]
+from typing import Protocol, Callable, TypeVar
+
+class P(Protocol):
+ attr: int
+class P2(Protocol):
+ attr2: str
+
+T = TypeVar('T')
+def f(x: Callable[[T, T], None]) -> T: pass
+def g(x: P, y: P2) -> None: pass
+x = f(g) # E: "f" does not return a value
+
+[case testMeetProtocolWithNormal]
+from typing import Protocol, Callable, TypeVar
+
+class P(Protocol):
+ attr: int
+class C:
+ attr: int
+
+T = TypeVar('T')
+def f(x: Callable[[T, T], None]) -> T: pass
+def g(x: P, y: C) -> None: pass
+reveal_type(f(g)) # E: Revealed type is '__main__.C*'
+
+[case testInferProtocolFromProtocol]
+from typing import Protocol, Sequence, TypeVar, Generic
+
+T = TypeVar('T')
+class Box(Protocol[T]):
+ content: T
+class Linked(Protocol[T]):
+ val: T
+ def next(self) -> Linked[T]: pass
+
+class L(Generic[T]):
+ val: Box[T]
+ def next(self) -> L[T]: pass
+
+def last(seq: Linked[T]) -> T:
+ pass
+
+reveal_type(last(L[int]())) # E: Revealed type is '__main__.Box*[builtins.int*]'
+reveal_type(last(L[str]()).content) # E: Revealed type is 'builtins.str*'
+
+[case testOverloadOnProtocol]
+from typing import overload, Protocol, runtime
+
+@runtime
+class P1(Protocol):
+ attr1: int
+class P2(Protocol):
+ attr2: str
+
+class C1:
+ attr1: int
+class C2:
+ attr2: str
+class C: pass
+
+@overload
+def f(x: P1) -> int: ...
+@overload
+def f(x: P2) -> str: ...
+def f(x):
+ if isinstance(x, P1):
+ return P1.attr1
+ if isinstance(x, P2): # E: Only @runtime protocols can be used with instance and class checks
+ return P1.attr2
+
+reveal_type(f(C1())) # E: Revealed type is 'builtins.int'
+reveal_type(f(C2())) # E: Revealed type is 'builtins.str'
+class D(C1, C2): pass # Compatible with both P1 and P2
+# FIXME: the below is not right, see #1322
+reveal_type(f(D())) # E: Revealed type is 'Any'
+f(C()) # E: No overload variant of "f" matches argument types [__main__.C]
+[builtins fixtures/isinstance.pyi]
+
+-- Unions of protocol types
+-- ------------------------
+
+[case testBasicUnionsOfProtocols]
+from typing import Union, Protocol
+
+class P1(Protocol):
+ attr1: int
+class P2(Protocol):
+ attr2: int
+
+class C1:
+ attr1: int
+class C2:
+ attr2: int
+class C(C1, C2):
+ pass
+
+class B: ...
+
+x: Union[P1, P2]
+
+x = C1()
+x = C2()
+x = C()
+x = B() # E: Incompatible types in assignment (expression has type "B", variable has type "Union[P1, P2]")
+
+[case testUnionsOfNormalClassesWithProtocols]
+from typing import Protocol, Union
+
+class P1(Protocol):
+ attr1: int
+class P2(Protocol):
+ attr2: int
+
+class C1:
+ attr1: int
+class C2:
+ attr2: int
+class C(C1, C2):
+ pass
+
+class D1:
+ attr1: int
+
+def f1(x: P1) -> None:
+ pass
+def f2(x: P2) -> None:
+ pass
+
+x: Union[C1, C2]
+y: Union[C1, D1]
+z: Union[C, D1]
+
+f1(x) # E: Argument 1 to "f1" has incompatible type "Union[C1, C2]"; expected "P1"
+f1(y)
+f1(z)
+f2(x) # E: Argument 1 to "f2" has incompatible type "Union[C1, C2]"; expected "P2"
+f2(z) # E: Argument 1 to "f2" has incompatible type "Union[C, D1]"; expected "P2"
+
+-- Type[] with protocol types
+-- --------------------------
+
+[case testInstantiationProtocolInTypeForFunctions]
+from typing import Type, Protocol
+
+class P(Protocol):
+ def m(self) -> None: pass
+class P1(Protocol):
+ def m(self) -> None: pass
+class Pbad(Protocol):
+ def mbad(self) -> int: pass
+class B(P): pass
+class C:
+ def m(self) -> None:
+ pass
+
+def f(cls: Type[P]) -> P:
+ return cls() # OK
+def g() -> P:
+ return P() # E: Cannot instantiate protocol class "P"
+
+f(P) # E: Only concrete class can be given where 'Type[P]' is expected
+f(B) # OK
+f(C) # OK
+x: Type[P1]
+xbad: Type[Pbad]
+f(x) # OK
+f(xbad) # E: Argument 1 to "f" has incompatible type Type[Pbad]; expected Type[P]
+
+[case testInstantiationProtocolInTypeForAliases]
+from typing import Type, Protocol
+
+class P(Protocol):
+ def m(self) -> None: pass
+class C:
+ def m(self) -> None:
+ pass
+
+def f(cls: Type[P]) -> P:
+ return cls() # OK
+
+Alias = P
+GoodAlias = C
+Alias() # E: Cannot instantiate protocol class "P"
+GoodAlias()
+f(Alias) # E: Only concrete class can be given where 'Type[P]' is expected
+f(GoodAlias)
+
+[case testInstantiationProtocolInTypeForVariables]
+from typing import Type, Protocol
+
+class P(Protocol):
+ def m(self) -> None: pass
+class B(P): pass
+class C:
+ def m(self) -> None:
+ pass
+
+var: Type[P]
+var()
+var = P # E: Can only assign concrete classes to a variable of type 'Type[P]'
+var = B # OK
+var = C # OK
+
+var_old = None # type: Type[P] # Old syntax for variable annotations
+var_old()
+var_old = P # E: Can only assign concrete classes to a variable of type 'Type[P]'
+var_old = B # OK
+var_old = C # OK
+
+[case testInstantiationProtocolInTypeForClassMethods]
+from typing import Type, Protocol
+
+class Logger:
+ @staticmethod
+ def log(a: Type[C]):
+ pass
+class C(Protocol):
+ @classmethod
+ def action(cls) -> None:
+ cls() #OK for classmethods
+ Logger.log(cls) #OK for classmethods
+[builtins fixtures/classmethod.pyi]
+
+-- isinstance() with @runtime protocols
+-- ------------------------------------
+
+[case testSimpleRuntimeProtocolCheck]
+from typing import Protocol, runtime
+
+@runtime # E: @runtime can only be used with protocol classes
+class C:
+ pass
+
+class P(Protocol):
+ def meth(self) -> None:
+ pass
+
+@runtime
+class R(Protocol):
+ def meth(self) -> int:
+ pass
+
+x: object
+
+if isinstance(x, P): # E: Only @runtime protocols can be used with instance and class checks
+ reveal_type(x) # E: Revealed type is '__main__.P'
+
+if isinstance(x, R):
+ reveal_type(x) # E: Revealed type is '__main__.R'
+ reveal_type(x.meth()) # E: Revealed type is 'builtins.int'
+[builtins fixtures/isinstance.pyi]
+
+[case testRuntimeIterableProtocolCheck]
+from typing import Iterable, List, Union
+
+x: Union[int, List[str]]
+
+if isinstance(x, Iterable):
+ reveal_type(x) # E: Revealed type is 'builtins.list[builtins.str]'
+[builtins fixtures/isinstancelist.pyi]
+
+[case testConcreteClassesInProtocolsIsInstance]
+from typing import Protocol, runtime, TypeVar, Generic
+
+T = TypeVar('T')
+
+@runtime
+class P1(Protocol):
+ def meth1(self) -> int:
+ pass
+@runtime
+class P2(Protocol):
+ def meth2(self) -> int:
+ pass
+@runtime
+class P(P1, P2, Protocol):
+ pass
+
+class C1(Generic[T]):
+ def meth1(self) -> T:
+ pass
+class C2:
+ def meth2(self) -> int:
+ pass
+class C(C1[int], C2): pass
+
+c = C()
+if isinstance(c, P1):
+ reveal_type(c) # E: Revealed type is '__main__.C'
+else:
+ reveal_type(c) # Unreachable
+if isinstance(c, P):
+ reveal_type(c) # E: Revealed type is '__main__.C'
+else:
+ reveal_type(c) # Unreachable
+
+c1i: C1[int]
+if isinstance(c1i, P1):
+ reveal_type(c1i) # E: Revealed type is '__main__.C1[builtins.int]'
+else:
+ reveal_type(c1i) # Unreachable
+if isinstance(c1i, P):
+ reveal_type(c1i) # Unreachable
+else:
+ reveal_type(c1i) # E: Revealed type is '__main__.C1[builtins.int]'
+
+c1s: C1[str]
+if isinstance(c1s, P1):
+ reveal_type(c1s) # Unreachable
+else:
+ reveal_type(c1s) # E: Revealed type is '__main__.C1[builtins.str]'
+
+c2: C2
+if isinstance(c2, P):
+ reveal_type(c2) # Unreachable
+else:
+ reveal_type(c2) # E: Revealed type is '__main__.C2'
+
+[builtins fixtures/isinstancelist.pyi]
+
+[case testConcreteClassesUnionInProtocolsIsInstance]
+from typing import Protocol, runtime, TypeVar, Generic, Union
+
+T = TypeVar('T')
+
+@runtime
+class P1(Protocol):
+ def meth1(self) -> int:
+ pass
+@runtime
+class P2(Protocol):
+ def meth2(self) -> int:
+ pass
+
+class C1(Generic[T]):
+ def meth1(self) -> T:
+ pass
+class C2:
+ def meth2(self) -> int:
+ pass
+
+x: Union[C1[int], C2]
+if isinstance(x, P1):
+ reveal_type(x) # E: Revealed type is '__main__.C1[builtins.int]'
+else:
+ reveal_type(x) # E: Revealed type is '__main__.C2'
+
+if isinstance(x, P2):
+ reveal_type(x) # E: Revealed type is '__main__.C2'
+else:
+ reveal_type(x) # E: Revealed type is '__main__.C1[builtins.int]'
+[builtins fixtures/isinstancelist.pyi]
+
+-- Non-Instances and protocol types (Callable vs __call__ etc.)
+-- ------------------------------------------------------------
+
+[case testBasicTupleStructuralSubtyping]
+from typing import Tuple, TypeVar, Protocol
+
+T = TypeVar('T', covariant=True)
+
+class MyProto(Protocol[T]):
+ def __len__(self) -> T:
+ pass
+
+t: Tuple[int, str]
+def f(x: MyProto[int]) -> None:
+ pass
+f(t) # OK
+
+y: MyProto[str]
+y = t # E: Incompatible types in assignment (expression has type "Tuple[int, str]", variable has type MyProto[str])
+[builtins fixtures/isinstancelist.pyi]
+
+[case testBasicNamedTupleStructuralSubtyping]
+from typing import NamedTuple, TypeVar, Protocol
+
+T = TypeVar('T', covariant=True)
+S = TypeVar('S', covariant=True)
+
+class P(Protocol[T, S]):
+ @property
+ def x(self) -> T: pass
+ @property
+ def y(self) -> S: pass
+
+class N(NamedTuple):
+ x: int
+ y: str
+class N2(NamedTuple):
+ x: int
+class N3(NamedTuple):
+ x: int
+ y: int
+
+z: N
+z3: N3
+
+def fun(x: P[int, str]) -> None:
+ pass
+def fun2(x: P[int, int]) -> None:
+ pass
+def fun3(x: P[T, T]) -> T:
+ return x.x
+
+fun(z)
+fun2(z) # E: Argument 1 to "fun2" has incompatible type "N"; expected P[int, int] \
+ # N: Following member(s) of "N" have conflicts: \
+ # N: y: expected "int", got "str"
+
+fun(N2(1)) # E: Argument 1 to "fun" has incompatible type "N2"; expected P[int, str] \
+ # N: 'N2' is missing following 'P' protocol member: \
+ # N: y
+
+reveal_type(fun3(z)) # E: Revealed type is 'builtins.object*'
+
+reveal_type(fun3(z3)) # E: Revealed type is 'builtins.int*'
+[builtins fixtures/list.pyi]
+
+[case testBasicCallableStructuralSubtyping]
+from typing import Callable, Generic, TypeVar
+
+def apply(f: Callable[[int], int], x: int) -> int:
+ return f(x)
+
+class Add5:
+ def __call__(self, x: int) -> int:
+ return x + 5
+
+apply(Add5(), 5)
+
+T = TypeVar('T')
+def apply_gen(f: Callable[[T], T]) -> T:
+ pass
+
+reveal_type(apply_gen(Add5())) # E: Revealed type is 'builtins.int*'
+def apply_str(f: Callable[[str], int], x: str) -> int:
+ return f(x)
+apply_str(Add5(), 'a') # E: Argument 1 to "apply_str" has incompatible type "Add5"; expected Callable[[str], int] \
+ # N: 'Add5.__call__' has type 'Callable[[Arg(int, 'x')], int]'
+[builtins fixtures/isinstancelist.pyi]
+
+[case testMoreComplexCallableStructuralSubtyping]
+from mypy_extensions import Arg, VarArg
+from typing import Protocol, Callable
+
+def call_soon(cb: Callable[[Arg(int, 'x'), VarArg(str)], int]): pass
+
+class Good:
+ def __call__(self, x: int, *rest: str) -> int: pass
+class Bad1:
+ def __call__(self, x: int, *rest: int) -> int: pass
+class Bad2:
+ def __call__(self, y: int, *rest: str) -> int: pass
+call_soon(Good())
+call_soon(Bad1()) # E: Argument 1 to "call_soon" has incompatible type "Bad1"; expected Callable[[int, VarArg(str)], int] \
+ # N: 'Bad1.__call__' has type 'Callable[[Arg(int, 'x'), VarArg(int)], int]'
+call_soon(Bad2()) # E: Argument 1 to "call_soon" has incompatible type "Bad2"; expected Callable[[int, VarArg(str)], int] \
+ # N: 'Bad2.__call__' has type 'Callable[[Arg(int, 'y'), VarArg(str)], int]'
+[builtins fixtures/isinstancelist.pyi]
+
+[case testStructuralSupportForPartial]
+from typing import Callable, TypeVar, Generic, Any
+
+T = TypeVar('T')
+
+class partial(Generic[T]):
+ def __init__(self, func: Callable[..., T], *args: Any) -> None: ...
+ def __call__(self, *args: Any) -> T: ...
+
+def inc(a: int, temp: str) -> int:
+ pass
+
+def foo(f: Callable[[int], T]) -> T:
+ return f(1)
+
+reveal_type(foo(partial(inc, 'temp'))) # E: Revealed type is 'builtins.int*'
+[builtins fixtures/list.pyi]
+
+[case testStructuralInferenceForCallable]
+from typing import Callable, TypeVar, Tuple
+
+T = TypeVar('T')
+S = TypeVar('S')
+
+class Actual:
+ def __call__(self, arg: int) -> str: pass
+
+def fun(cb: Callable[[T], S]) -> Tuple[T, S]: pass
+reveal_type(fun(Actual())) # E: Revealed type is 'Tuple[builtins.int*, builtins.str*]'
+[builtins fixtures/tuple.pyi]
+
+-- Standard protocol types (SupportsInt, Sized, etc.)
+-- --------------------------------------------------
+
+-- More tests could be added for types from typing converted to protocols
+
+[case testBasicSizedProtocol]
+from typing import Sized
+
+class Foo:
+ def __len__(self) -> int:
+ return 42
+
+def bar(a: Sized) -> int:
+ return a.__len__()
+
+bar(Foo())
+bar((1, 2))
+bar(1) # E: Argument 1 to "bar" has incompatible type "int"; expected "Sized"
+
+[builtins fixtures/isinstancelist.pyi]
+
+[case testBasicSupportsIntProtocol]
+from typing import SupportsInt
+
+class Bar:
+ def __int__(self):
+ return 1
+
+def foo(a: SupportsInt):
+ pass
+
+foo(Bar())
+foo('no way') # E: Argument 1 to "foo" has incompatible type "str"; expected "SupportsInt"
+
+[builtins fixtures/isinstancelist.pyi]
+
+-- Additional tests and corner cases for protocols
+-- ----------------------------------------------
+
+[case testAnyWithProtocols]
+from typing import Protocol, Any, TypeVar
+
+T = TypeVar('T')
+
+class P1(Protocol):
+ attr1: int
+class P2(Protocol[T]):
+ attr2: T
+class P3(Protocol):
+ attr: P3
+
+def f1(x: P1) -> None: pass
+def f2(x: P2[str]) -> None: pass
+def f3(x: P3) -> None: pass
+
+class C1:
+ attr1: Any
+class C2:
+ attr2: Any
+class C3:
+ attr: Any
+
+f1(C1())
+f2(C2())
+f3(C3())
+
+f2(C3()) # E: Argument 1 to "f2" has incompatible type "C3"; expected P2[str]
+a: Any
+f1(a)
+f2(a)
+f3(a)
+
+[case testErrorsForProtocolsInDifferentPlaces]
+from typing import Protocol
+
+class P(Protocol):
+ attr1: int
+ attr2: str
+ attr3: int
+
+class C:
+ attr1: str
+ @property
+ def attr2(self) -> int: pass
+
+x: P = C() # E: Incompatible types in assignment (expression has type "C", variable has type "P") \
+ # N: 'C' is missing following 'P' protocol member: \
+ # N: attr3 \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: attr1: expected "int", got "str" \
+ # N: attr2: expected "str", got "int" \
+ # N: Protocol member P.attr2 expected settable variable, got read-only attribute
+
+def f(x: P) -> P:
+ return C() # E: Incompatible return value type (got "C", expected "P") \
+ # N: 'C' is missing following 'P' protocol member: \
+ # N: attr3 \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: attr1: expected "int", got "str" \
+ # N: attr2: expected "str", got "int" \
+ # N: Protocol member P.attr2 expected settable variable, got read-only attribute
+
+f(C()) # E: Argument 1 to "f" has incompatible type "C"; expected "P" \
+ # N: 'C' is missing following 'P' protocol member: \
+ # N: attr3 \
+ # N: Following member(s) of "C" have conflicts: \
+ # N: attr1: expected "int", got "str" \
+ # N: attr2: expected "str", got "int" \
+ # N: Protocol member P.attr2 expected settable variable, got read-only attribute
+[builtins fixtures/list.pyi]
+
+[case testIterableProtocolOnClass]
+from typing import TypeVar, Iterator
+T = TypeVar('T', bound='A')
+
+class A:
+ def __iter__(self: T) -> Iterator[T]: pass
+
+class B(A): pass
+
+reveal_type(list(b for b in B())) # E: Revealed type is 'builtins.list[__main__.B*]'
+reveal_type(list(B())) # E: Revealed type is 'builtins.list[__main__.B*]'
+[builtins fixtures/list.pyi]
+
+[case testIterableProtocolOnMetaclass]
+from typing import TypeVar, Iterator, Type
+T = TypeVar('T')
+
+class EMeta(type):
+ def __iter__(self: Type[T]) -> Iterator[T]: pass
+
+class E(metaclass=EMeta):
+ pass
+
+class C(E):
+ pass
+
+reveal_type(list(c for c in C)) # E: Revealed type is 'builtins.list[__main__.C*]'
+reveal_type(list(C)) # E: Revealed type is 'builtins.list[__main__.C*]'
+[builtins fixtures/list.pyi]
+
+[case testClassesGetattrWithProtocols]
+from typing import Protocol
+
+class P(Protocol):
+ attr: int
+
+class PP(Protocol):
+ @property
+ def attr(self) -> int:
+ pass
+
+class C:
+ def __getattr__(self, attr: str) -> int:
+ pass
+class C2(C):
+ def __setattr__(self, attr: str, val: int) -> None:
+ pass
+
+class D:
+ def __getattr__(self, attr: str) -> str:
+ pass
+
+def fun(x: P) -> None:
+ reveal_type(P.attr) # E: Revealed type is 'builtins.int'
+def fun_p(x: PP) -> None:
+ reveal_type(P.attr) # E: Revealed type is 'builtins.int'
+
+fun(C()) # E: Argument 1 to "fun" has incompatible type "C"; expected "P" \
+ # N: Protocol member P.attr expected settable variable, got read-only attribute
+fun(C2())
+fun_p(D()) # E: Argument 1 to "fun_p" has incompatible type "D"; expected "PP" \
+ # N: Following member(s) of "D" have conflicts: \
+ # N: attr: expected "int", got "str"
+fun_p(C()) # OK
+[builtins fixtures/list.pyi]
+
+[case testImplicitTypesInProtocols]
+from typing import Protocol
+
+class P(Protocol):
+ x = 1 # E: All protocol members must have explicitly declared types
+
+class C:
+ x: int
+
+class D:
+ x: str
+
+x: P
+x = D() # E: Incompatible types in assignment (expression has type "D", variable has type "P") \
+ # N: Following member(s) of "D" have conflicts: \
+ # N: x: expected "int", got "str"
+x = C() # OK
+[builtins fixtures/list.pyi]
+
+[case testProtocolIncompatibilityWithGenericMethod]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T')
+S = TypeVar('S')
+
+class A(Protocol):
+ def f(self, x: T) -> None: pass
+class B:
+ def f(self, x: S, y: T) -> None: pass
+
+x: A = B()
+[out]
+main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A")
+main:11: note: Following member(s) of "B" have conflicts:
+main:11: note: Expected:
+main:11: note: def [T] f(self, x: T) -> None
+main:11: note: Got:
+main:11: note: def [S, T] f(self, x: S, y: T) -> None
+
+[case testProtocolIncompatibilityWithGenericMethodBounded]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T')
+S = TypeVar('S', bound=int)
+
+class A(Protocol):
+ def f(self, x: T) -> None: pass
+class B:
+ def f(self, x: S, y: T) -> None: pass
+
+x: A = B()
+[out]
+main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A")
+main:11: note: Following member(s) of "B" have conflicts:
+main:11: note: Expected:
+main:11: note: def [T] f(self, x: T) -> None
+main:11: note: Got:
+main:11: note: def [S <: int, T] f(self, x: S, y: T) -> None
+
+[case testProtocolIncompatibilityWithGenericRestricted]
+from typing import Protocol, TypeVar
+
+T = TypeVar('T')
+S = TypeVar('S', int, str)
+
+class A(Protocol):
+ def f(self, x: T) -> None: pass
+class B:
+ def f(self, x: S, y: T) -> None: pass
+
+x: A = B()
+[out]
+main:11: error: Incompatible types in assignment (expression has type "B", variable has type "A")
+main:11: note: Following member(s) of "B" have conflicts:
+main:11: note: Expected:
+main:11: note: def [T] f(self, x: T) -> None
+main:11: note: Got:
+main:11: note: def [S in (int, str), T] f(self, x: S, y: T) -> None
+
+[case testProtocolIncompatibilityWithManyOverloads]
+from typing import Protocol, overload
+
+class C1: pass
+class C2: pass
+class A(Protocol):
+ @overload
+ def f(self, x: int) -> int: pass
+ @overload
+ def f(self, x: str) -> str: pass
+ @overload
+ def f(self, x: C1) -> C2: pass
+ @overload
+ def f(self, x: C2) -> C1: pass
+
+class B:
+ def f(self) -> None: pass
+
+x: A = B()
+[out]
+main:18: error: Incompatible types in assignment (expression has type "B", variable has type "A")
+main:18: note: Following member(s) of "B" have conflicts:
+main:18: note: Expected:
+main:18: note: @overload
+main:18: note: def f(self, x: int) -> int
+main:18: note: @overload
+main:18: note: def f(self, x: str) -> str
+main:18: note: <2 more overload(s) not shown>
+main:18: note: Got:
+main:18: note: def f(self) -> None
+
+[case testProtocolIncompatibilityWithManyConflicts]
+from typing import Protocol
+
+class A(Protocol):
+ def f(self, x: int) -> None: pass
+ def g(self, x: int) -> None: pass
+ def h(self, x: int) -> None: pass
+ def i(self, x: int) -> None: pass
+class B:
+ def f(self, x: str) -> None: pass
+ def g(self, x: str) -> None: pass
+ def h(self, x: str) -> None: pass
+ def i(self, x: str) -> None: pass
+
+x: A = B()
+[out]
+main:14: error: Incompatible types in assignment (expression has type "B", variable has type "A")
+main:14: note: Following member(s) of "B" have conflicts:
+main:14: note: Expected:
+main:14: note: def f(self, x: int) -> None
+main:14: note: Got:
+main:14: note: def f(self, x: str) -> None
+main:14: note: Expected:
+main:14: note: def g(self, x: int) -> None
+main:14: note: Got:
+main:14: note: def g(self, x: str) -> None
+main:14: note: <2 more conflict(s) not shown>
+
+[case testDontShowNotesForTupleAndIterableProtocol]
+from typing import Iterable, Sequence, Protocol, NamedTuple
+
+class N(NamedTuple):
+ x: int
+
+def f1(x: Iterable[str]) -> None: pass
+def f2(x: Sequence[str]) -> None: pass
+
+# The errors below should be short
+f1(N(1)) # E: Argument 1 to "f1" has incompatible type "N"; expected Iterable[str]
+f2(N(2)) # E: Argument 1 to "f2" has incompatible type "N"; expected Sequence[str]
+[builtins fixtures/tuple.pyi]
+
+[case testNotManyFlagConflitsShownInProtocols]
+from typing import Protocol
+
+class AllSettable(Protocol):
+ a: int
+ b: int
+ c: int
+ d: int
+
+class AllReadOnly:
+ @property
+ def a(self) -> int: pass
+ @property
+ def b(self) -> int: pass
+ @property
+ def c(self) -> int: pass
+ @property
+ def d(self) -> int: pass
+
+x: AllSettable = AllReadOnly()
+[builtins fixtures/property.pyi]
+[out]
+main:19: error: Incompatible types in assignment (expression has type "AllReadOnly", variable has type "AllSettable")
+main:19: note: Protocol member AllSettable.a expected settable variable, got read-only attribute
+main:19: note: Protocol member AllSettable.b expected settable variable, got read-only attribute
+main:19: note: <2 more conflict(s) not shown>
+
+[case testProtocolsMoreConflictsNotShown]
+from typing_extensions import Protocol
+from typing import Generic, TypeVar
+
+T = TypeVar('T')
+
+class MockMapping(Protocol[T]):
+ def a(self, x: T) -> int: pass
+ def b(self, x: T) -> int: pass
+ def c(self, x: T) -> int: pass
+ d: T
+ e: T
+ f: T
+
+class MockDict(MockMapping[T]):
+ more: int
+
+def f(x: MockMapping[int]) -> None: pass
+x: MockDict[str]
+f(x) # E: Argument 1 to "f" has incompatible type MockDict[str]; expected MockMapping[int]
+
+[case testProtocolNotesForComplexSignatures]
+from typing import Protocol, Optional
+
+class P(Protocol):
+ def meth(self, x: int, *args: str) -> None: pass
+ def other(self, *args, hint: Optional[str] = None, **kwargs: str) -> None: pass
+class C:
+ def meth(self) -> int: pass
+ def other(self) -> int: pass
+
+x: P = C()
+[builtins fixtures/dict.pyi]
+[out]
+main:10: error: Incompatible types in assignment (expression has type "C", variable has type "P")
+main:10: note: Following member(s) of "C" have conflicts:
+main:10: note: Expected:
+main:10: note: def meth(self, x: int, *args: str) -> None
+main:10: note: Got:
+main:10: note: def meth(self) -> int
+main:10: note: Expected:
+main:10: note: def other(self, *args: Any, hint: Optional[str] = ..., **kwargs: str) -> None
+main:10: note: Got:
+main:10: note: def other(self) -> int
+
diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test
index 193c7468d442..a13d1e704f41 100644
--- a/test-data/unit/check-typeddict.test
+++ b/test-data/unit/check-typeddict.test
@@ -343,7 +343,9 @@ Point = TypedDict('Point', {'x': int, 'y': int})
def as_dict(p: Point) -> Dict[str, int]:
return p # E: Incompatible return value type (got "Point", expected Dict[str, int])
def as_mutable_mapping(p: Point) -> MutableMapping[str, int]:
- return p # E: Incompatible return value type (got "Point", expected MutableMapping[str, int])
+ return p # E: Incompatible return value type (got "Point", expected MutableMapping[str, int]) \
+ # N: 'Point' is missing following 'MutableMapping' protocol member: \
+ # N: __setitem__
[builtins fixtures/dict.pyi]
[case testCanConvertTypedDictToAny]
@@ -372,6 +374,52 @@ ll = [b, c]
f(ll) # E: Argument 1 to "f" has incompatible type List[TypedDict({'x': int, 'z': str})]; expected "A"
[builtins fixtures/dict.pyi]
+[case testTypedDictWithSimpleProtocol]
+from typing_extensions import Protocol
+from mypy_extensions import TypedDict
+
+class StrIntMap(Protocol):
+ def __getitem__(self, key: str) -> int: ...
+
+A = TypedDict('A', {'x': int, 'y': int})
+B = TypedDict('B', {'x': int, 'y': str})
+
+def fun(arg: StrIntMap) -> None: ...
+a: A
+b: B
+fun(a)
+fun(b) # Error
+[builtins fixtures/dict.pyi]
+[out]
+main:14: error: Argument 1 to "fun" has incompatible type "B"; expected "StrIntMap"
+main:14: note: Following member(s) of "B" have conflicts:
+main:14: note: Expected:
+main:14: note: def __getitem__(self, str) -> int
+main:14: note: Got:
+main:14: note: def __getitem__(self, str) -> object
+
+[case testTypedDictWithSimpleProtocolInference]
+from typing_extensions import Protocol
+from mypy_extensions import TypedDict
+from typing import TypeVar
+
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+
+class StrMap(Protocol[T_co]):
+ def __getitem__(self, key: str) -> T_co: ...
+
+A = TypedDict('A', {'x': int, 'y': int})
+B = TypedDict('B', {'x': int, 'y': str})
+
+def fun(arg: StrMap[T]) -> T:
+ return arg['whatever']
+a: A
+b: B
+reveal_type(fun(a)) # E: Revealed type is 'builtins.int*'
+reveal_type(fun(b)) # E: Revealed type is 'builtins.object*'
+[builtins fixtures/dict.pyi]
+[out]
-- Join
@@ -1132,9 +1180,16 @@ def f(x: int) -> None: ...
def f(x): pass
a: A
-f(a) # E: Argument 1 to "f" has incompatible type "A"; expected Iterable[int]
+f(a)
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]
+[out]
+main:13: error: Argument 1 to "f" has incompatible type "A"; expected Iterable[int]
+main:13: note: Following member(s) of "A" have conflicts:
+main:13: note: Expected:
+main:13: note: def __iter__(self) -> Iterator[int]
+main:13: note: Got:
+main:13: note: def __iter__(self) -> Iterator[str]
[case testTypedDictOverloading3]
from typing import overload
diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi
index a271315e4dde..cf8b61f9397a 100644
--- a/test-data/unit/fixtures/dict.pyi
+++ b/test-data/unit/fixtures/dict.pyi
@@ -11,11 +11,12 @@ class object:
class type: pass
-class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]):
+class dict(Generic[KT, VT]):
@overload
def __init__(self, **kwargs: VT) -> None: pass
@overload
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
+ def __getitem__(self, key: KT) -> VT: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def update(self, a: Mapping[KT, VT]) -> None: pass
@@ -23,6 +24,7 @@ class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]):
def get(self, k: KT) -> Optional[VT]: pass
@overload
def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass
+ def __len__(self) -> int: ...
class int: # for convenience
def __add__(self, x: int) -> int: pass
@@ -30,7 +32,7 @@ class int: # for convenience
class str: pass # for keyword argument key type
class unicode: pass # needed for py2 docstrings
-class list(Iterable[T], Generic[T]): # needed by some test cases
+class list(Generic[T]): # needed by some test cases
def __getitem__(self, x: int) -> T: pass
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
@@ -41,4 +43,5 @@ class float: pass
class bool: pass
class ellipsis: pass
+def isinstance(x: object, t: Union[type, Tuple]) -> bool: pass
class BaseException: pass
diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi
index 930d309aa3db..99aca1befe39 100644
--- a/test-data/unit/fixtures/isinstancelist.pyi
+++ b/test-data/unit/fixtures/isinstancelist.pyi
@@ -23,16 +23,17 @@ T = TypeVar('T')
KT = TypeVar('KT')
VT = TypeVar('VT')
-class tuple(Generic[T]): pass
+class tuple(Generic[T]):
+ def __len__(self) -> int: pass
-class list(Iterable[T]):
+class list(Generic[T]):
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __setitem__(self, x: int, v: T) -> None: pass
def __getitem__(self, x: int) -> T: pass
def __add__(self, x: List[T]) -> T: pass
-class dict(Iterable[KT], Mapping[KT, VT]):
+class dict(Mapping[KT, VT]):
@overload
def __init__(self, **kwargs: VT) -> None: pass
@overload
@@ -41,7 +42,7 @@ class dict(Iterable[KT], Mapping[KT, VT]):
def __iter__(self) -> Iterator[KT]: pass
def update(self, a: Mapping[KT, VT]) -> None: pass
-class set(Iterable[T]):
+class set(Generic[T]):
def __iter__(self) -> Iterator[T]: pass
def add(self, x: T) -> None: pass
def discard(self, x: T) -> None: pass
diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi
index d5d1000d3364..7b6d1dbd127b 100644
--- a/test-data/unit/fixtures/list.pyi
+++ b/test-data/unit/fixtures/list.pyi
@@ -10,7 +10,7 @@ class object:
class type: pass
class ellipsis: pass
-class list(Iterable[T], Generic[T]):
+class list(Generic[T]):
@overload
def __init__(self) -> None: pass
@overload
diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi
index d43e34065907..62fac70034c0 100644
--- a/test-data/unit/fixtures/typing-full.pyi
+++ b/test-data/unit/fixtures/typing-full.pyi
@@ -17,6 +17,7 @@ Union = 0
Optional = 0
TypeVar = 0
Generic = 0
+Protocol = 0
Tuple = 0
Callable = 0
_promote = 0
@@ -33,33 +34,42 @@ Dict = 0
Set = 0
T = TypeVar('T')
+T_co = TypeVar('T_co', covariant=True)
+T_contra = TypeVar('T_contra', contravariant=True)
U = TypeVar('U')
V = TypeVar('V')
S = TypeVar('S')
-class Container(Generic[T]):
+# Note: definitions below are different from typeshed, variances are declared
+# to silence the protocol variance checks. Maybe it is better to use type: ignore?
+
+@runtime
+class Container(Protocol[T_contra]):
@abstractmethod
# Use int because bool isn't in the default test builtins
- def __contains__(self, arg: T) -> int: pass
+ def __contains__(self, arg: T_contra) -> int: pass
-class Sized:
+@runtime
+class Sized(Protocol):
@abstractmethod
def __len__(self) -> int: pass
-class Iterable(Generic[T]):
+@runtime
+class Iterable(Protocol[T_co]):
@abstractmethod
- def __iter__(self) -> 'Iterator[T]': pass
+ def __iter__(self) -> 'Iterator[T_co]': pass
-class Iterator(Iterable[T], Generic[T]):
+@runtime
+class Iterator(Iterable[T_co], Protocol):
@abstractmethod
- def __next__(self) -> T: pass
+ def __next__(self) -> T_co: pass
class Generator(Iterator[T], Generic[T, U, V]):
@abstractmethod
def send(self, value: U) -> T: pass
@abstractmethod
- def throw(self, typ: Any, val: Any=None, tb=None) -> None: pass
+ def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass
@abstractmethod
def close(self) -> None: pass
@@ -83,38 +93,52 @@ class AsyncGenerator(AsyncIterator[T], Generic[T, U]):
@abstractmethod
def __aiter__(self) -> 'AsyncGenerator[T, U]': pass
-class Awaitable(Generic[T]):
+@runtime
+class Awaitable(Protocol[T]):
@abstractmethod
def __await__(self) -> Generator[Any, Any, T]: pass
class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S]):
pass
-class AsyncIterable(Generic[T]):
+@runtime
+class AsyncIterable(Protocol[T]):
@abstractmethod
def __aiter__(self) -> 'AsyncIterator[T]': pass
-class AsyncIterator(AsyncIterable[T], Generic[T]):
+@runtime
+class AsyncIterator(AsyncIterable[T], Protocol):
def __aiter__(self) -> 'AsyncIterator[T]': return self
@abstractmethod
def __anext__(self) -> Awaitable[T]: pass
-class Sequence(Iterable[T], Generic[T]):
+@runtime
+class Sequence(Iterable[T_co], Protocol):
@abstractmethod
- def __getitem__(self, n: Any) -> T: pass
+ def __getitem__(self, n: Any) -> T_co: pass
-class Mapping(Iterable[T], Sized, Generic[T, U]):
+@runtime
+class Mapping(Iterable[T], Protocol[T, T_co]):
+ def __getitem__(self, key: T) -> T_co: pass
@overload
- def get(self, k: T) -> Optional[U]: ...
+ def get(self, k: T) -> Optional[T_co]: pass
@overload
- def get(self, k: T, default: Union[U, V]) -> Union[U, V]: ...
- def values(self) -> Iterable[U]: pass # Approximate return type
+ def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass
+ def values(self) -> Iterable[T_co]: pass # Approximate return type
def __len__(self) -> int: ...
-class MutableMapping(Mapping[T, U]): pass
+@runtime
+class MutableMapping(Mapping[T, U], Protocol):
+ def __setitem__(self, k: T, v: U) -> None: pass
+
+class SupportsInt(Protocol):
+ def __int__(self) -> int: pass
+
+def runtime(cls: T) -> T:
+ return cls
class ContextManager(Generic[T]):
- def __enter__(self) -> T: ...
- def __exit__(self, exc_type, exc_value, traceback): ...
+ def __enter__(self) -> T: pass
+ def __exit__(self, exc_type, exc_value, traceback): pass
TYPE_CHECKING = 1
diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi
index 457bea0e9020..87b50f532376 100644
--- a/test-data/unit/lib-stub/builtins.pyi
+++ b/test-data/unit/lib-stub/builtins.pyi
@@ -15,7 +15,6 @@ class bytes: pass
class tuple: pass
class function: pass
-
class ellipsis: pass
# Definition of None is implicit
diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi
index 02412c77a7b8..8be4abca389a 100644
--- a/test-data/unit/lib-stub/typing.pyi
+++ b/test-data/unit/lib-stub/typing.pyi
@@ -12,6 +12,7 @@ Union = 0
Optional = 0
TypeVar = 0
Generic = 0
+Protocol = 0 # This is not yet defined in typeshed, see PR typeshed/#1220
Tuple = 0
Callable = 0
_promote = 0
@@ -28,37 +29,57 @@ Dict = 0
Set = 0
T = TypeVar('T')
+T_co = TypeVar('T_co', covariant=True)
+T_contra = TypeVar('T_contra', contravariant=True)
U = TypeVar('U')
V = TypeVar('V')
S = TypeVar('S')
-class Container(Generic[T]):
+# Note: definitions below are different from typeshed, variances are declared
+# to silence the protocol variance checks. Maybe it is better to use type: ignore?
+
+@runtime
+class Container(Protocol[T_contra]):
@abstractmethod
# Use int because bool isn't in the default test builtins
- def __contains__(self, arg: T) -> int: pass
+ def __contains__(self, arg: T_contra) -> int: pass
-class Sized:
+@runtime
+class Sized(Protocol):
@abstractmethod
def __len__(self) -> int: pass
-class Iterable(Generic[T]):
+@runtime
+class Iterable(Protocol[T_co]):
@abstractmethod
- def __iter__(self) -> 'Iterator[T]': pass
+ def __iter__(self) -> 'Iterator[T_co]': pass
-class Iterator(Iterable[T], Generic[T]):
+@runtime
+class Iterator(Iterable[T_co], Protocol):
@abstractmethod
- def __next__(self) -> T: pass
+ def __next__(self) -> T_co: pass
class Generator(Iterator[T], Generic[T, U, V]):
@abstractmethod
def __iter__(self) -> 'Generator[T, U, V]': pass
-class Sequence(Iterable[T], Generic[T]):
+@runtime
+class Sequence(Iterable[T_co], Protocol):
@abstractmethod
- def __getitem__(self, n: Any) -> T: pass
+ def __getitem__(self, n: Any) -> T_co: pass
+
+@runtime
+class Mapping(Protocol[T_contra, T_co]):
+ def __getitem__(self, key: T_contra) -> T_co: pass
+
+@runtime
+class MutableMapping(Mapping[T_contra, U], Protocol):
+ def __setitem__(self, k: T_contra, v: U) -> None: pass
-class Mapping(Generic[T, U]): pass
+class SupportsInt(Protocol):
+ def __int__(self) -> int: pass
-class MutableMapping(Generic[T, U]): pass
+def runtime(cls: T) -> T:
+ return cls
TYPE_CHECKING = 1
diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi
new file mode 100644
index 000000000000..8c5be8f3637f
--- /dev/null
+++ b/test-data/unit/lib-stub/typing_extensions.pyi
@@ -0,0 +1,6 @@
+from typing import TypeVar
+
+_T = TypeVar('_T')
+
+class Protocol: pass
+def runtime(x: _T) -> _T: pass
diff --git a/test-data/unit/semanal-errors.test b/test-data/unit/semanal-errors.test
index 2192bce3221d..0dba8f2836bf 100644
--- a/test-data/unit/semanal-errors.test
+++ b/test-data/unit/semanal-errors.test
@@ -953,7 +953,7 @@ from typing import Generic, TypeVar
T = TypeVar('T')
S = TypeVar('S')
class A(Generic[T], Generic[S]): pass \
- # E: Duplicate Generic in bases
+ # E: Only single Generic[...] or Protocol[...] can be in bases
[out]
[case testInvalidMetaclass]