diff --git a/mypy/binder.py b/mypy/binder.py index 45855aa1b9d5..109fef25ce6a 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -10,6 +10,7 @@ from mypy.subtypes import is_subtype from mypy.join import join_simple from mypy.sametypes import is_same_type +from mypy.erasetype import remove_instance_last_known_values from mypy.nodes import Expression, Var, RefExpr from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import IndexExpr, MemberExpr, NameExpr @@ -248,6 +249,10 @@ def assign_type(self, expr: Expression, type: Type, declared_type: Optional[Type], restrict_any: bool = False) -> None: + # We should erase last known value in binder, because if we are using it, + # it means that the target is not final, and therefore can't hold a literal. + type = remove_instance_last_known_values(type) + type = get_proper_type(type) declared_type = get_proper_type(declared_type) diff --git a/mypy/checker.py b/mypy/checker.py index 36907b43a94a..1a2512662193 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1954,6 +1954,8 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type partial_types = self.find_partial_types(var) if partial_types is not None: if not self.current_node_deferred: + # Partial type can't be final, so strip any literal values. + rvalue_type = remove_instance_last_known_values(rvalue_type) inferred_type = UnionType.make_simplified_union( [rvalue_type, NoneType()]) self.set_inferred_type(var, lvalue, inferred_type) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 94a85c0f5097..03e93857ec38 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3067,3 +3067,29 @@ reveal_type(Test3.FOO.name) # N: Revealed type is 'builtins.str' reveal_type(Test4.FOO.name) # N: Revealed type is 'builtins.str' reveal_type(Test5.FOO.name) # N: Revealed type is 'builtins.str' [out] + +[case testLiteralBinderLastValueErased] +# mypy: strict-equality + +from typing_extensions import Literal + +def takes_three(x: Literal[3]) -> None: ... +x: object +x = 3 + +takes_three(x) # E: Argument 1 to "takes_three" has incompatible type "int"; expected "Literal[3]" +if x == 2: # OK + ... +[builtins fixtures/bool.pyi] + +[case testLiteralBinderLastValueErasedPartialTypes] +# mypy: strict-equality + +def test() -> None: + x = None + if bool(): + x = 1 + + if x == 2: # OK + ... +[builtins fixtures/bool.pyi]