From f7bec14ff41d58a785771954bd70b7b893220279 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 16:12:28 -0500 Subject: [PATCH 1/4] use base type as hint when inferring Would help with #12268. Modified a test because the two behaviours conflict right now. --- mypy/checker.py | 14 +++++++++++++- test-data/unit/check-inference.test | 21 +++++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 851f23185f4f..13c0d7c7b990 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2376,7 +2376,19 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type self.check_indexed_assignment(index_lvalue, rvalue, lvalue) if inferred: - rvalue_type = self.expr_checker.accept(rvalue) + type_hint = None + if isinstance(lvalue, NameExpr): + inferred_node = lvalue.node + if (lvalue.kind in (MDEF, None) and # None for Vars defined via self + len(inferred_node.info.bases) > 0): + for base in inferred_node.info.mro[1:]: + base_type, base_node = self.lvalue_type_from_base(inferred_node, base) + + if base_type: + type_hint = base_type + break + + rvalue_type = self.expr_checker.accept(rvalue, type_hint) if not inferred.is_final: rvalue_type = remove_instance_last_known_values(rvalue_type) self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 4de6e4a76f92..f2f5424a98bf 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2813,8 +2813,8 @@ class C(A): x = ['12'] reveal_type(A.x) # N: Revealed type is "builtins.list[Any]" -reveal_type(B.x) # N: Revealed type is "builtins.list[builtins.int]" -reveal_type(C.x) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(B.x) # N: Revealed type is "builtins.list[Any]" +reveal_type(C.x) # N: Revealed type is "builtins.list[Any]" [builtins fixtures/list.pyi] @@ -3255,3 +3255,20 @@ reveal_type(x) # N: Revealed type is "builtins.bytes*" if x: reveal_type(x) # N: Revealed type is "builtins.bytes*" [builtins fixtures/dict.pyi] + +[case testInferHintFromBase] +from typing import Union, List + +class A: pass +class B: pass + +U = Union[A, B] + +class Base: + variable1 : List[U] = [] + variable2 : List[U] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] From 5ff9c8bae5c039bc151e56a526e2145e944b3328 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 19:24:17 -0500 Subject: [PATCH 2/4] make sure the value is the right type --- mypy/checker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 13c0d7c7b990..b8898f1ebf32 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2377,9 +2377,10 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type if inferred: type_hint = None - if isinstance(lvalue, NameExpr): + if isinstance(lvalue, RefExpr): inferred_node = lvalue.node - if (lvalue.kind in (MDEF, None) and # None for Vars defined via self + if (isinstance(inferred_node, Var) and + lvalue.kind in (MDEF, None) and # None for Vars defined via self len(inferred_node.info.bases) > 0): for base in inferred_node.info.mro[1:]: base_type, base_node = self.lvalue_type_from_base(inferred_node, base) From cbaebd07a2dc6ee90634e0d920b5938b560864aa Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 20:16:03 -0500 Subject: [PATCH 3/4] force github to link issue Closes #12268. --- test-data/unit/check-inference.test | 1 - 1 file changed, 1 deletion(-) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index f2f5424a98bf..8f4692dee2fc 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3271,4 +3271,3 @@ class Base: class Derived(Base): variable1 = [A()] variable2 = variable1 -[out] From d1d6c5b30052deef8c2f4fd22b5463c923bc7a0c Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 20:20:25 -0500 Subject: [PATCH 4/4] add out back --- test-data/unit/check-inference.test | 1 + 1 file changed, 1 insertion(+) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 8f4692dee2fc..f2f5424a98bf 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3271,3 +3271,4 @@ class Base: class Derived(Base): variable1 = [A()] variable2 = variable1 +[out]