Skip to content

Fix inference of protocol against overloaded function #12227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,12 @@ def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Co
return res

def visit_overloaded(self, template: Overloaded) -> List[Constraint]:
if isinstance(self.actual, CallableType):
items = find_matching_overload_items(template, self.actual)
else:
items = template.items
res: List[Constraint] = []
for t in template.items:
for t in items:
res.extend(infer_constraints(t, self.actual, self.direction))
return res

Expand Down Expand Up @@ -701,3 +705,22 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
# Fall back to the first item if we can't find a match. This is totally arbitrary --
# maybe we should just bail out at this point.
return items[0]


def find_matching_overload_items(overloaded: Overloaded,
template: CallableType) -> List[CallableType]:
"""Like find_matching_overload_item, but return all matches, not just the first."""
items = overloaded.items
res = []
for item in items:
# Return type may be indeterminate in the template, so ignore it when performing a
# subtype check.
if mypy.subtypes.is_callable_compatible(item, template,
is_compat=mypy.subtypes.is_subtype,
ignore_return=True):
res.append(item)
if not res:
# Falling back to all items if we can't find a match is pretty arbitrary, but
# it maintains backward compatibility.
res = items[:]
return res
83 changes: 83 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -2806,3 +2806,86 @@ class MyClass:
assert isinstance(self, MyProtocol)
[builtins fixtures/isinstance.pyi]
[typing fixtures/typing-full.pyi]

[case testMatchProtocolAgainstOverloadWithAmbiguity]
from typing import TypeVar, Protocol, Union, Generic, overload

T = TypeVar("T", covariant=True)

class slice: pass

class GetItem(Protocol[T]):
def __getitem__(self, k: int) -> T: ...

class Str: # Resembles 'str'
def __getitem__(self, k: Union[int, slice]) -> Str: ...

class Lst(Generic[T]): # Resembles 'list'
def __init__(self, x: T): ...
@overload
def __getitem__(self, k: int) -> T: ...
@overload
def __getitem__(self, k: slice) -> Lst[T]: ...
def __getitem__(self, k): pass

def f(x: GetItem[GetItem[Str]]) -> None: ...

a: Lst[Str]
f(Lst(a))

class Lst2(Generic[T]):
def __init__(self, x: T): ...
# The overload items are tweaked but still compatible
@overload
def __getitem__(self, k: Str) -> None: ...
@overload
def __getitem__(self, k: slice) -> Lst2[T]: ...
@overload
def __getitem__(self, k: Union[int, str]) -> T: ...
def __getitem__(self, k): pass

b: Lst2[Str]
f(Lst2(b))

class Lst3(Generic[T]): # Resembles 'list'
def __init__(self, x: T): ...
# The overload items are no longer compatible (too narrow argument type)
@overload
def __getitem__(self, k: slice) -> Lst3[T]: ...
@overload
def __getitem__(self, k: bool) -> T: ...
def __getitem__(self, k): pass

c: Lst3[Str]
f(Lst3(c)) # E: Argument 1 to "f" has incompatible type "Lst3[Lst3[Str]]"; expected "GetItem[GetItem[Str]]" \
# N: Following member(s) of "Lst3[Lst3[Str]]" have conflicts: \
# N: Expected: \
# N: def __getitem__(self, int) -> GetItem[Str] \
# N: Got: \
# N: @overload \
# N: def __getitem__(self, slice) -> Lst3[Lst3[Str]] \
# N: @overload \
# N: def __getitem__(self, bool) -> Lst3[Str]

[builtins fixtures/list.pyi]
[typing fixtures/typing-full.pyi]

[case testMatchProtocolAgainstOverloadWithMultipleMatchingItems]
from typing import Protocol, overload, TypeVar, Any

_T_co = TypeVar("_T_co", covariant=True)
_T = TypeVar("_T")

class SupportsRound(Protocol[_T_co]):
@overload
def __round__(self) -> int: ...
@overload
def __round__(self, __ndigits: int) -> _T_co: ...

class C:
# This matches both overload items of SupportsRound
def __round__(self, __ndigits: int = ...) -> int: ...

def round(number: SupportsRound[_T], ndigits: int) -> _T: ...

round(C(), 1)