Skip to content

Commit 0061d6e

Browse files
authored
Allow narrowing enum values using == (#11521)
Resolves #10915, resolves #9786 See the discussion in #10915. I'm sympathetic to the difference between identity and equality here being surprising and that mypy doesn't usually make concessions to mutability when type checking. The old test cases are pretty explicit about their intentions and are worth reading. Curious to see what people (and mypy-primer) have to say about this. Co-authored-by: hauntsaninja <>
1 parent 31708f3 commit 0061d6e

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

mypy/checker.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -5835,7 +5835,15 @@ def refine_identity_comparison_expression(
58355835
"""
58365836
should_coerce = True
58375837
if coerce_only_in_literal_context:
5838-
should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices)
5838+
5839+
def should_coerce_inner(typ: Type) -> bool:
5840+
typ = get_proper_type(typ)
5841+
return is_literal_type_like(typ) or (
5842+
isinstance(typ, Instance)
5843+
and typ.type.is_enum
5844+
)
5845+
5846+
should_coerce = any(should_coerce_inner(operand_types[i]) for i in chain_indices)
58395847

58405848
target: Type | None = None
58415849
possible_target_indices = []

test-data/unit/check-narrowing.test

+33-24
Original file line numberDiff line numberDiff line change
@@ -703,47 +703,47 @@ class FlipFlopStr:
703703
def mutate(self) -> None:
704704
self.state = "state-2" if self.state == "state-1" else "state-1"
705705

706-
def test1(switch: FlipFlopEnum) -> None:
706+
707+
def test1(switch: FlipFlopStr) -> None:
707708
# Naively, we might assume the 'assert' here would narrow the type to
708-
# Literal[State.A]. However, doing this ends up breaking a fair number of real-world
709+
# Literal["state-1"]. However, doing this ends up breaking a fair number of real-world
709710
# code (usually test cases) that looks similar to this function: e.g. checks
710711
# to make sure a field was mutated to some particular value.
711712
#
712713
# And since mypy can't really reason about state mutation, we take a conservative
713714
# approach and avoid narrowing anything here.
714715

715-
assert switch.state == State.A
716-
reveal_type(switch.state) # N: Revealed type is "__main__.State"
716+
assert switch.state == "state-1"
717+
reveal_type(switch.state) # N: Revealed type is "builtins.str"
717718

718719
switch.mutate()
719720

720-
assert switch.state == State.B
721-
reveal_type(switch.state) # N: Revealed type is "__main__.State"
721+
assert switch.state == "state-2"
722+
reveal_type(switch.state) # N: Revealed type is "builtins.str"
722723

723724
def test2(switch: FlipFlopEnum) -> None:
724-
# So strictly speaking, we ought to do the same thing with 'is' comparisons
725-
# for the same reasons as above. But in practice, not too many people seem to
726-
# know that doing 'some_enum is MyEnum.Value' is idiomatic. So in practice,
727-
# this is probably good enough for now.
725+
# This is the same thing as 'test1', except we use enums, which we allow to be narrowed
726+
# to literals.
728727

729-
assert switch.state is State.A
728+
assert switch.state == State.A
730729
reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]"
731730

732731
switch.mutate()
733732

734-
assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
733+
assert switch.state == State.B # E: Non-overlapping equality check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
735734
reveal_type(switch.state) # E: Statement is unreachable
736735

737-
def test3(switch: FlipFlopStr) -> None:
738-
# This is the same thing as 'test1', except we try using str literals.
736+
def test3(switch: FlipFlopEnum) -> None:
737+
# Same thing, but using 'is' comparisons. Previously mypy's behaviour differed
738+
# here, narrowing when using 'is', but not when using '=='.
739739

740-
assert switch.state == "state-1"
741-
reveal_type(switch.state) # N: Revealed type is "builtins.str"
740+
assert switch.state is State.A
741+
reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]"
742742

743743
switch.mutate()
744744

745-
assert switch.state == "state-2"
746-
reveal_type(switch.state) # N: Revealed type is "builtins.str"
745+
assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
746+
reveal_type(switch.state) # E: Statement is unreachable
747747
[builtins fixtures/primitives.pyi]
748748

749749
[case testNarrowingEqualityRequiresExplicitStrLiteral]
@@ -795,6 +795,7 @@ reveal_type(x_union) # N: Revealed type is "Union[Literal['A'], Literal['B'
795795

796796
[case testNarrowingEqualityRequiresExplicitEnumLiteral]
797797
# flags: --strict-optional
798+
from typing import Union
798799
from typing_extensions import Literal, Final
799800
from enum import Enum
800801

@@ -805,26 +806,34 @@ class Foo(Enum):
805806
A_final: Final = Foo.A
806807
A_literal: Literal[Foo.A]
807808

808-
# See comments in testNarrowingEqualityRequiresExplicitStrLiteral and
809-
# testNarrowingEqualityFlipFlop for more on why we can't narrow here.
809+
# Note this is unlike testNarrowingEqualityRequiresExplicitStrLiteral
810+
# See also testNarrowingEqualityFlipFlop
810811
x1: Foo
811812
if x1 == Foo.A:
812-
reveal_type(x1) # N: Revealed type is "__main__.Foo"
813+
reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.A]"
813814
else:
814-
reveal_type(x1) # N: Revealed type is "__main__.Foo"
815+
reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.B]"
815816

816817
x2: Foo
817818
if x2 == A_final:
818-
reveal_type(x2) # N: Revealed type is "__main__.Foo"
819+
reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.A]"
819820
else:
820-
reveal_type(x2) # N: Revealed type is "__main__.Foo"
821+
reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.B]"
821822

822823
# But we let this narrow since there's an explicit literal in the RHS.
823824
x3: Foo
824825
if x3 == A_literal:
825826
reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.A]"
826827
else:
827828
reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.B]"
829+
830+
831+
class SingletonFoo(Enum):
832+
A = "A"
833+
834+
def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None:
835+
if x == y:
836+
reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]"
828837
[builtins fixtures/primitives.pyi]
829838

830839
[case testNarrowingEqualityDisabledForCustomEquality]

0 commit comments

Comments
 (0)