Skip to content

Force enum literals to simplify when inferring unions #7904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set
from typing import cast, Optional, List, Sequence, Set, Dict
import sys

from mypy.types import (
Expand Down Expand Up @@ -300,6 +300,11 @@ def make_simplified_union(items: Sequence[Type],
* [int, int] -> int
* [int, Any] -> Union[int, Any] (Any types are not simplified away!)
* [Any, Any] -> Any
* [Literal[Foo.A], Literal[Foo.B]] -> Foo (assuming Foo is an enum with two variants A and B)

Note that we only collapse enum literals into the original enum when all literal variants
are present. Since enums are effectively final and there is a fixed number of possible
variants, it's safe to treat the full union of literals and the enum itself as equivalent.

Note: This must NOT be used during semantic analysis, since TypeInfos may not
be fully initialized.
Expand All @@ -316,6 +321,8 @@ def make_simplified_union(items: Sequence[Type],

from mypy.subtypes import is_proper_subtype

enums_found = {} # type: Dict[str, int]
enum_max_members = {} # type: Dict[str, int]
removed = set() # type: Set[int]
for i, ti in enumerate(items):
if i in removed: continue
Expand All @@ -327,13 +334,52 @@ def make_simplified_union(items: Sequence[Type],
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false

# if deleted subtypes had more general truthiness, use that
if not ti.can_be_true and cbt:
items[i] = true_or_false(ti)
items[i] = ti = true_or_false(ti)
elif not ti.can_be_false and cbf:
items[i] = true_or_false(ti)
items[i] = ti = true_or_false(ti)

# Keep track of all enum Literal types we encounter, in case
# we can coalesce them together
if isinstance(ti, LiteralType) and ti.is_enum_literal():
enum_name = ti.fallback.type.fullname()
if enum_name not in enum_max_members:
enum_max_members[enum_name] = len(get_enum_values(ti.fallback))
enums_found[enum_name] = enums_found.get(enum_name, 0) + 1
if isinstance(ti, Instance) and ti.type.is_enum:
enum_name = ti.type.fullname()
if enum_name not in enum_max_members:
enum_max_members[enum_name] = len(get_enum_values(ti))
enums_found[enum_name] = enum_max_members[enum_name]

enums_to_compress = {n for (n, c) in enums_found.items() if c >= enum_max_members[n]}
enums_encountered = set() # type: Set[str]
simplified_set = [] # type: List[ProperType]
for i, item in enumerate(items):
if i in removed:
continue

# Try seeing if this is an enum or enum literal, and if it's
# one we should be collapsing away.
if isinstance(item, LiteralType):
instance = item.fallback # type: Optional[Instance]
elif isinstance(item, Instance):
instance = item
else:
instance = None

if instance and instance.type.is_enum:
enum_name = instance.type.fullname()
if enum_name in enums_encountered:
continue
if enum_name in enums_to_compress:
simplified_set.append(instance)
enums_encountered.add(enum_name)
continue
simplified_set.append(item)

simplified_set = [items[i] for i in range(len(items)) if i not in removed]
return UnionType.make_union(simplified_set, line, column)


Expand Down Expand Up @@ -551,7 +597,7 @@ class Status(Enum):

if isinstance(typ, UnionType):
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
return make_simplified_union(items)
return UnionType.make_union(items)
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname:
new_items = []
for name, symbol in typ.type.names.items():
Expand All @@ -566,7 +612,7 @@ class Status(Enum):
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items)
return UnionType.make_union(new_items)
else:
return typ

Expand All @@ -578,7 +624,7 @@ def coerce_to_literal(typ: Type) -> ProperType:
typ = get_proper_type(typ)
if isinstance(typ, UnionType):
new_items = [coerce_to_literal(item) for item in typ.items]
return make_simplified_union(new_items)
return UnionType.make_union(new_items)
elif isinstance(typ, Instance):
if typ.last_known_value:
return typ.last_known_value
Expand Down
100 changes: 97 additions & 3 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ elif x is Foo.C:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is '__main__.Foo'

if Foo.A is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -638,6 +639,7 @@ elif Foo.C is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is '__main__.Foo'

y: Foo
if y is Foo.A:
Expand All @@ -648,6 +650,7 @@ elif y is Foo.C:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
reveal_type(y) # N: Revealed type is '__main__.Foo'

if Foo.A is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -657,6 +660,7 @@ elif Foo.C is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
reveal_type(y) # N: Revealed type is '__main__.Foo'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksIndirect]
Expand Down Expand Up @@ -686,6 +690,8 @@ if y is x:
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'

if x is z:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -703,6 +709,8 @@ else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # N: Revealed type is '__main__.Foo*'
accepts_foo_a(z)
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(z) # N: Revealed type is '__main__.Foo*'

if y is z:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -718,6 +726,8 @@ if z is y:
else:
reveal_type(y) # No output: this branch is unreachable
reveal_type(z) # No output: this branch is unreachable
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is '__main__.Foo*'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityNoNarrowingForUnionMessiness]
Expand All @@ -740,13 +750,17 @@ if x is y:
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'

if y is z:
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
else:
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithNone]
Expand All @@ -764,16 +778,19 @@ if x:
reveal_type(x) # N: Revealed type is '__main__.Foo'
else:
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'

if x is not None:
reveal_type(x) # N: Revealed type is '__main__.Foo'
else:
reveal_type(x) # N: Revealed type is 'None'
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'

if x is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithMultipleEnums]
Expand All @@ -793,18 +810,21 @@ if x1 is Foo.A:
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

x2: Union[Foo, Bar]
if x2 is Bar.A:
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
else:
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

x3: Union[Foo, Bar]
if x3 is Foo.A or x3 is Bar.A:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
else:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

[builtins fixtures/bool.pyi]

Expand All @@ -823,7 +843,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
# E: Unsupported left operand type for + ("Empty") \
# N: Left operand is of type "Union[int, None, Empty]"
if x is _empty:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
reveal_type(x) # N: Revealed type is '__main__.Empty'
return 0
elif x is None:
reveal_type(x) # N: Revealed type is 'None'
Expand Down Expand Up @@ -870,7 +890,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
# E: Unsupported left operand type for + ("Empty") \
# N: Left operand is of type "Union[int, None, Empty]"
if x is _empty:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
reveal_type(x) # N: Revealed type is '__main__.Empty'
return 0
elif x is None:
reveal_type(x) # N: Revealed type is 'None'
Expand Down Expand Up @@ -899,7 +919,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
# E: Unsupported left operand type for + ("Empty") \
# N: Left operand is of type "Union[int, None, Empty]"
if x is _empty:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
reveal_type(x) # N: Revealed type is '__main__.Empty'
return 0
elif x is None:
reveal_type(x) # N: Revealed type is 'None'
Expand All @@ -908,3 +928,77 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
reveal_type(x) # N: Revealed type is 'builtins.int'
return x + 2
[builtins fixtures/primitives.pyi]

[case testEnumUnionCompression]
from typing import Union
from typing_extensions import Literal
from enum import Enum

class Foo(Enum):
A = 1
B = 2
C = 3

class Bar(Enum):
X = 1
Y = 2

x1: Literal[Foo.A, Foo.B, Foo.B, Foo.B, 1, None]
assert x1 is not None
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[1]]'

x2: Literal[1, Foo.A, Foo.B, Foo.C, None]
assert x2 is not None
reveal_type(x2) # N: Revealed type is 'Union[Literal[1], __main__.Foo]'

x3: Literal[Foo.A, Foo.B, 1, Foo.C, Foo.C, Foo.C, None]
assert x3 is not None
reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, Literal[1]]'

x4: Literal[Foo.A, Foo.B, Foo.C, Foo.C, Foo.C, None]
assert x4 is not None
reveal_type(x4) # N: Revealed type is '__main__.Foo'

x5: Union[Literal[Foo.A], Foo, None]
assert x5 is not None
reveal_type(x5) # N: Revealed type is '__main__.Foo'

x6: Literal[Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C, None]
assert x6 is not None
reveal_type(x6) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

# TODO: We should really simplify this down into just 'Bar' as well.
no_forcing: Literal[Bar.X, Bar.X, Bar.Y]
reveal_type(no_forcing) # N: Revealed type is 'Union[Literal[__main__.Bar.X], Literal[__main__.Bar.X], Literal[__main__.Bar.Y]]'

[case testEnumUnionCompressionAssignment]
from typing_extensions import Literal
from enum import Enum

class Foo(Enum):
A = 1
B = 2

class Wrapper1:
def __init__(self, x: object, y: Foo) -> None:
if x:
if y is Foo.A:
pass
else:
pass
self.y = y
else:
self.y = y
reveal_type(self.y) # N: Revealed type is '__main__.Foo'

class Wrapper2:
def __init__(self, x: object, y: Foo) -> None:
if x:
self.y = y
else:
if y is Foo.A:
pass
else:
pass
self.y = y
reveal_type(self.y) # N: Revealed type is '__main__.Foo'