diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 35c58478ce1e..9b15075da85e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,7 +19,7 @@ PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, is_named_instance, FunctionLike, StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, - get_proper_types + get_proper_types, flatten_nested_unions ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -2521,7 +2521,9 @@ def check_op(self, method: str, base_type: Type, left_variants = [base_type] base_type = get_proper_type(base_type) if isinstance(base_type, UnionType): - left_variants = [item for item in base_type.relevant_items()] + left_variants = [item for item in + flatten_nested_unions(base_type.relevant_items(), + handle_type_alias_type=True)] right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -2564,8 +2566,8 @@ def check_op(self, method: str, base_type: Type, right_type = get_proper_type(right_type) if isinstance(right_type, UnionType): right_variants = [(item, TempNode(item, context=context)) - for item in right_type.relevant_items()] - + for item in flatten_nested_unions(right_type.relevant_items(), + handle_type_alias_type=True)] msg = self.msg.clean_copy() msg.disable_count = 0 all_results = [] diff --git a/mypy/types.py b/mypy/types.py index ae678acedb3a..2890cc35b22b 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2231,15 +2231,19 @@ def has_type_vars(typ: Type) -> bool: return typ.accept(HasTypeVars()) -def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: +def flatten_nested_unions(types: Iterable[Type], + handle_type_alias_type: bool = False) -> List[Type]: """Flatten nested unions in a type list.""" # This and similar functions on unions can cause infinite recursion # if passed a "pathological" alias like A = Union[int, A] or similar. # TODO: ban such aliases in semantic analyzer. flat_items = [] # type: List[Type] + if handle_type_alias_type: + types = get_proper_types(types) for tp in types: if isinstance(tp, ProperType) and isinstance(tp, UnionType): - flat_items.extend(flatten_nested_unions(tp.items)) + flat_items.extend(flatten_nested_unions(tp.items, + handle_type_alias_type=handle_type_alias_type)) else: flat_items.append(tp) return flat_items diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index ed2b415e8f99..92e886fee419 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -1013,3 +1013,18 @@ y: Union[int, Dict[int, int]] = 1 if bool() else {} u: Union[int, List[int]] = [] if bool() else 1 v: Union[int, Dict[int, int]] = {} if bool() else 1 [builtins fixtures/isinstancelist.pyi] + +[case testFlattenTypeAliasWhenAliasedAsUnion] +from typing import Union + +T1 = int +T2 = Union[T1, float] +T3 = Union[T2, complex] +T4 = Union[T3, int] + +def foo(a: T2, b: T2) -> T2: + return a + b + +def bar(a: T4, b: T4) -> T4: # test multi-level alias + return a + b +[builtins fixtures/ops.pyi] diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 34cfb176243e..0c3497b1667f 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -64,6 +64,10 @@ class float: def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass +class complex: + def __add__(self, x: complex) -> complex: pass + def __radd__(self, x: complex) -> complex: pass + class BaseException: pass def __print(a1=None, a2=None, a3=None, a4=None): pass