From b1d5b923615d8b6d348227e5ae52e909b35400e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= Date: Sat, 1 Jun 2024 14:46:41 +0200 Subject: [PATCH 1/3] Fix union inference of generic class and its generic type --- mypy/constraints.py | 16 ++++++++++++++- test-data/unit/check-inference.test | 30 ++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index cdfa39ac45f3..922f26a9de17 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -390,10 +390,24 @@ def _infer_constraints( # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. + + def _is_item_being_overlaped_by_other(item: Type) -> bool: + # It returns true if the item is an argument of other item + # that is subtype of the actual type + return any( + isinstance(p_type := get_proper_type(item_to_compare), Instance) + and mypy.subtypes.is_subtype(actual, erase_typevars(p_type)) + and item in p_type.args + for item_to_compare in template.items + if item is not item_to_compare + ) + result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) - for t_item in template.items + for t_item in [ + item for item in template.items if not _is_item_being_overlaped_by_other(item) + ] ], eager=False, ) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 08b53ab16972..ca05cb73c335 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -873,13 +873,7 @@ def g(x: Union[T, List[T]]) -> List[T]: pass def h(x: List[str]) -> None: pass g('a')() # E: "List[str]" not callable -# The next line is a case where there are multiple ways to satisfy a constraint -# involving a Union. Either T = List[str] or T = str would turn out to be valid, -# but mypy doesn't know how to branch on these two options (and potentially have -# to backtrack later) and defaults to T = Never. The result is an -# awkward error message. Either a better error message, or simply accepting the -# call, would be preferable here. -g(['a']) # E: Argument 1 to "g" has incompatible type "List[str]"; expected "List[Never]" +g(['a']) h(g(['a'])) @@ -891,6 +885,28 @@ i(b, a, b) i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]" [builtins fixtures/list.pyi] +[case testUnionInferenceOfGenericClassAndItsGenericType] +from typing import Generic, TypeVar, Union + +T = TypeVar('T') + +class GenericClass(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + +def method_with_union(arg: Union[GenericClass[T], T]) -> GenericClass[T]: + if not isinstance(arg, GenericClass): + arg = GenericClass(arg) + return arg + +result_1 = method_with_union(GenericClass("test")) +reveal_type(result_1) # N: Revealed type is "__main__.GenericClass[builtins.str]" + +result_2 = method_with_union("test") +reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]" + +[builtins fixtures/isinstance.pyi] + [case testCallableListJoinInference] from typing import Any, Callable From 2314852da781cca51f5306d2cb046de3c88719ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= Date: Sun, 2 Jun 2024 17:13:26 +0200 Subject: [PATCH 2/3] Improve inference of union of generic types when one of the types is the generic type of the other --- mypy/constraints.py | 69 ++++++++++++++++++++++++----- test-data/unit/check-inference.test | 27 ++++++++++- 2 files changed, 84 insertions(+), 12 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 922f26a9de17..5f4c3c5437b8 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -271,7 +271,11 @@ def infer_constraints_for_callable( def infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool = False + template: Type, + actual: Type, + direction: int, + skip_neg_op: bool = False, + can_have_union_overlaping: bool = True, ) -> list[Constraint]: """Infer type constraints. @@ -311,11 +315,15 @@ def infer_constraints( res = _infer_constraints(template, actual, direction, skip_neg_op) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction, skip_neg_op) + return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlaping) def _infer_constraints( - template: Type, actual: Type, direction: int, skip_neg_op: bool + template: Type, + actual: Type, + direction: int, + skip_neg_op: bool, + can_have_union_overlaping: bool = True, ) -> list[Constraint]: orig_template = template template = get_proper_type(template) @@ -368,8 +376,41 @@ def _infer_constraints( return res if direction == SUPERTYPE_OF and isinstance(actual, UnionType): res = [] + + def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool: + # There is a special overlaping case, where we have a Union of where two types + # are the same, but one of them contains the other. + # For example, we have Union[Sequence[T], Sequence[Sequence[T]]] + # In this case, only the second one can have overlaping because it contains the other. + # So, in case of list[list[int]], second one would be chosen. + if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args: + other_items = [o_item for o_item in _actual.items if o_item is not a_item] + + if len(other_items) == 1 and other_items[0] in p_item.args: + return True + + if len(other_items) > 1: + union_args = [ + p_arg + for arg in p_item.args + if isinstance(p_arg := get_proper_type(arg), UnionType) + ] + + for union_arg in union_args: + if all(o_item in union_arg.items for o_item in other_items): + return True + + return False + for a_item in actual.items: - res.extend(infer_constraints(orig_template, a_item, direction)) + res.extend( + infer_constraints( + orig_template, + a_item, + direction, + can_have_union_overlaping=_can_have_overlaping(a_item, actual), + ) + ) return res # Now the potential subtype is known not to be a Union or a type @@ -391,22 +432,28 @@ def _infer_constraints( # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - def _is_item_being_overlaped_by_other(item: Type) -> bool: - # It returns true if the item is an argument of other item + def _is_item_overlaping_actual_type(_item: Type) -> bool: + # Overlaping occurs when we have a Union where two types are + # compatible and the more generic one is chosen. + # For example, in Union[T, Sequence[T]], we have to choose + # Sequence[T] if actual type is list[int]. + # This returns true if the item is an argument of other item # that is subtype of the actual type return any( - isinstance(p_type := get_proper_type(item_to_compare), Instance) - and mypy.subtypes.is_subtype(actual, erase_typevars(p_type)) - and item in p_type.args + isinstance(p_item_to_compare := get_proper_type(item_to_compare), Instance) + and mypy.subtypes.is_subtype(actual, erase_typevars(p_item_to_compare)) + and _item in p_item_to_compare.args for item_to_compare in template.items - if item is not item_to_compare + if _item is not item_to_compare ) result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) for t_item in [ - item for item in template.items if not _is_item_being_overlaped_by_other(item) + item + for item in template.items + if not (can_have_union_overlaping and _is_item_overlaping_actual_type(item)) ] ], eager=False, diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index ca05cb73c335..96d42b47d7ee 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -885,7 +885,7 @@ i(b, a, b) i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]" [builtins fixtures/list.pyi] -[case testUnionInferenceOfGenericClassAndItsGenericType] +[case testInferenceOfUnionOfGenericClassAndItsGenericType] from typing import Generic, TypeVar, Union T = TypeVar('T') @@ -907,6 +907,31 @@ reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str] [builtins fixtures/isinstance.pyi] +[case testInferenceOfUnionOfSequenceOfAnyAndSequenceOfSequence] +from typing import Sequence, Iterable, TypeVar, Union + +T = TypeVar("T") +S = TypeVar("S") + +def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]: + pass + +def method(value: Union[Sequence[T], Sequence[Sequence[T]]]) -> None: + reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[typing.Sequence[T`-1]]" + +[case testInferenceOfUnionOfUnionWithSequenceAndSequenceOfThatUnion] +from typing import Sequence, Iterable, TypeVar, Union + +T = Union[str, Sequence[int]] +S = TypeVar("S", bound=T) + +def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]: + pass + +def method(value: Union[T, Sequence[T]]) -> None: + reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[Union[builtins.str, typing.Sequence[builtins.int]]]" + + [case testCallableListJoinInference] from typing import Any, Callable From 133c49b458e980b5ad154d4050b681fb2ac4eafe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= Date: Sat, 26 Oct 2024 11:57:45 +0200 Subject: [PATCH 3/3] Fix 'overlapping' typo --- mypy/constraints.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 1a9b7c9f0ab6..e580aca80f24 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -278,7 +278,7 @@ def infer_constraints( actual: Type, direction: int, skip_neg_op: bool = False, - can_have_union_overlaping: bool = True, + can_have_union_overlapping: bool = True, ) -> list[Constraint]: """Infer type constraints. @@ -318,7 +318,7 @@ def infer_constraints( res = _infer_constraints(template, actual, direction, skip_neg_op) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlaping) + return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlapping) def _infer_constraints( @@ -326,7 +326,7 @@ def _infer_constraints( actual: Type, direction: int, skip_neg_op: bool, - can_have_union_overlaping: bool = True, + can_have_union_overlapping: bool = True, ) -> list[Constraint]: orig_template = template template = get_proper_type(template) @@ -380,11 +380,11 @@ def _infer_constraints( if direction == SUPERTYPE_OF and isinstance(actual, UnionType): res = [] - def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool: - # There is a special overlaping case, where we have a Union of where two types + def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool: + # There is a special overlapping case, where we have a Union of where two types # are the same, but one of them contains the other. # For example, we have Union[Sequence[T], Sequence[Sequence[T]]] - # In this case, only the second one can have overlaping because it contains the other. + # In this case, only the second one can have overlapping because it contains the other. # So, in case of list[list[int]], second one would be chosen. if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args: other_items = [o_item for o_item in _actual.items if o_item is not a_item] @@ -411,7 +411,7 @@ def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool: orig_template, a_item, direction, - can_have_union_overlaping=_can_have_overlaping(a_item, actual), + can_have_union_overlapping=_can_have_overlapping(a_item, actual), ) ) return res @@ -435,8 +435,8 @@ def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool: # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - def _is_item_overlaping_actual_type(_item: Type) -> bool: - # Overlaping occurs when we have a Union where two types are + def _is_item_overlapping_actual_type(_item: Type) -> bool: + # Overlapping occurs when we have a Union where two types are # compatible and the more generic one is chosen. # For example, in Union[T, Sequence[T]], we have to choose # Sequence[T] if actual type is list[int]. @@ -456,7 +456,7 @@ def _is_item_overlaping_actual_type(_item: Type) -> bool: for t_item in [ item for item in template.items - if not (can_have_union_overlaping and _is_item_overlaping_actual_type(item)) + if not (can_have_union_overlapping and _is_item_overlapping_actual_type(item)) ] ], eager=False,