Skip to content

Commit fe8309a

Browse files
TH3CHARLieilevkivskyi
authored andcommitted
Flatten TypeAliasType when it is aliased as a Union (#8146)
Resolves #8125 The main problem is not about flattening unions inside variants since the following code generates no error ```python from typing import Union T1 = Union[int, float] T2 = Union[Union[Union[int, float], float], Union[float, complex], complex] def foo(a: T2, b: T2) -> T2: return a + b ``` The problem, however, is because when using `TypeAliasType` to alias a Union, the `TypeAliasType` will not get flattened, so this PR fixes this.
1 parent 04366e7 commit fe8309a

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

mypy/checkexpr.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
2020
is_named_instance, FunctionLike,
2121
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType,
22-
get_proper_types
22+
get_proper_types, flatten_nested_unions
2323
)
2424
from mypy.nodes import (
2525
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
@@ -2589,7 +2589,9 @@ def check_op(self, method: str, base_type: Type,
25892589
left_variants = [base_type]
25902590
base_type = get_proper_type(base_type)
25912591
if isinstance(base_type, UnionType):
2592-
left_variants = [item for item in base_type.relevant_items()]
2592+
left_variants = [item for item in
2593+
flatten_nested_unions(base_type.relevant_items(),
2594+
handle_type_alias_type=True)]
25932595
right_type = self.accept(arg)
25942596

25952597
# Step 1: We first try leaving the right arguments alone and destructure
@@ -2632,8 +2634,8 @@ def check_op(self, method: str, base_type: Type,
26322634
right_type = get_proper_type(right_type)
26332635
if isinstance(right_type, UnionType):
26342636
right_variants = [(item, TempNode(item, context=context))
2635-
for item in right_type.relevant_items()]
2636-
2637+
for item in flatten_nested_unions(right_type.relevant_items(),
2638+
handle_type_alias_type=True)]
26372639
msg = self.msg.clean_copy()
26382640
msg.disable_count = 0
26392641
all_results = []

mypy/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,15 +2236,19 @@ def has_type_vars(typ: Type) -> bool:
22362236
return typ.accept(HasTypeVars())
22372237

22382238

2239-
def flatten_nested_unions(types: Iterable[Type]) -> List[Type]:
2239+
def flatten_nested_unions(types: Iterable[Type],
2240+
handle_type_alias_type: bool = False) -> List[Type]:
22402241
"""Flatten nested unions in a type list."""
22412242
# This and similar functions on unions can cause infinite recursion
22422243
# if passed a "pathological" alias like A = Union[int, A] or similar.
22432244
# TODO: ban such aliases in semantic analyzer.
22442245
flat_items = [] # type: List[Type]
2246+
if handle_type_alias_type:
2247+
types = get_proper_types(types)
22452248
for tp in types:
22462249
if isinstance(tp, ProperType) and isinstance(tp, UnionType):
2247-
flat_items.extend(flatten_nested_unions(tp.items))
2250+
flat_items.extend(flatten_nested_unions(tp.items,
2251+
handle_type_alias_type=handle_type_alias_type))
22482252
else:
22492253
flat_items.append(tp)
22502254
return flat_items

test-data/unit/check-unions.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,18 @@ y: Union[int, Dict[int, int]] = 1 if bool() else {}
10131013
u: Union[int, List[int]] = [] if bool() else 1
10141014
v: Union[int, Dict[int, int]] = {} if bool() else 1
10151015
[builtins fixtures/isinstancelist.pyi]
1016+
1017+
[case testFlattenTypeAliasWhenAliasedAsUnion]
1018+
from typing import Union
1019+
1020+
T1 = int
1021+
T2 = Union[T1, float]
1022+
T3 = Union[T2, complex]
1023+
T4 = Union[T3, int]
1024+
1025+
def foo(a: T2, b: T2) -> T2:
1026+
return a + b
1027+
1028+
def bar(a: T4, b: T4) -> T4: # test multi-level alias
1029+
return a + b
1030+
[builtins fixtures/ops.pyi]

test-data/unit/fixtures/ops.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class float:
6464
def __truediv__(self, x: 'float') -> 'float': pass
6565
def __rtruediv__(self, x: 'float') -> 'float': pass
6666

67+
class complex:
68+
def __add__(self, x: complex) -> complex: pass
69+
def __radd__(self, x: complex) -> complex: pass
70+
6771
class BaseException: pass
6872

6973
def __print(a1=None, a2=None, a3=None, a4=None): pass

0 commit comments

Comments
 (0)