Skip to content

Commit 57de8db

Browse files
authored
Use supertype context for variable type inference (#13494)
1 parent 61a9b92 commit 57de8db

File tree

6 files changed

+146
-7
lines changed

6 files changed

+146
-7
lines changed

mypy/checker.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1624,6 +1624,8 @@ def check_slots_definition(self, typ: Type, context: Context) -> None:
16241624

16251625
def check_match_args(self, var: Var, typ: Type, context: Context) -> None:
16261626
"""Check that __match_args__ contains literal strings"""
1627+
if not self.scope.active_class():
1628+
return
16271629
typ = get_proper_type(typ)
16281630
if not isinstance(typ, TupleType) or not all(
16291631
[is_string_literal(item) for item in typ.items]
@@ -2686,7 +2688,8 @@ def check_assignment(
26862688
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
26872689

26882690
if inferred:
2689-
rvalue_type = self.expr_checker.accept(rvalue)
2691+
type_context = self.get_variable_type_context(inferred)
2692+
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
26902693
if not (
26912694
inferred.is_final
26922695
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
@@ -2698,6 +2701,27 @@ def check_assignment(
26982701
# (type, operator) tuples for augmented assignments supported with partial types
26992702
partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")}
27002703

2704+
def get_variable_type_context(self, inferred: Var) -> Type | None:
2705+
type_contexts = []
2706+
if inferred.info:
2707+
for base in inferred.info.mro[1:]:
2708+
base_type, base_node = self.lvalue_type_from_base(inferred, base)
2709+
if base_type and not (
2710+
isinstance(base_node, Var) and base_node.invalid_partial_type
2711+
):
2712+
type_contexts.append(base_type)
2713+
# Use most derived supertype as type context if available.
2714+
if not type_contexts:
2715+
return None
2716+
candidate = type_contexts[0]
2717+
for other in type_contexts:
2718+
if is_proper_subtype(other, candidate):
2719+
candidate = other
2720+
elif not is_subtype(candidate, other):
2721+
# Multiple incompatible candidates, cannot use any of them as context.
2722+
return None
2723+
return candidate
2724+
27012725
def try_infer_partial_generic_type_from_assignment(
27022726
self, lvalue: Lvalue, rvalue: Expression, op: str
27032727
) -> None:
@@ -5870,7 +5894,9 @@ def enter_partial_types(
58705894
self.msg.need_annotation_for_var(var, context, self.options.python_version)
58715895
self.partial_reported.add(var)
58725896
if var.type:
5873-
var.type = self.fixup_partial_type(var.type)
5897+
fixed = self.fixup_partial_type(var.type)
5898+
var.invalid_partial_type = fixed != var.type
5899+
var.type = fixed
58745900

58755901
def handle_partial_var_type(
58765902
self, typ: PartialType, is_lvalue: bool, node: Var, context: Context

mypy/nodes.py

+5
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ def deserialize(cls, data: JsonDict) -> Decorator:
939939
"explicit_self_type",
940940
"is_ready",
941941
"is_inferred",
942+
"invalid_partial_type",
942943
"from_module_getattr",
943944
"has_explicit_value",
944945
"allow_incompatible_override",
@@ -975,6 +976,7 @@ class Var(SymbolNode):
975976
"from_module_getattr",
976977
"has_explicit_value",
977978
"allow_incompatible_override",
979+
"invalid_partial_type",
978980
)
979981

980982
def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
@@ -1024,6 +1026,9 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
10241026
self.has_explicit_value = False
10251027
# If True, subclasses can override this with an incompatible type.
10261028
self.allow_incompatible_override = False
1029+
# If True, this means we didn't manage to infer full type and fall back to
1030+
# something like list[Any]. We may decide to not use such types as context.
1031+
self.invalid_partial_type = False
10271032

10281033
@property
10291034
def name(self) -> str:

mypy/semanal.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -3007,7 +3007,10 @@ def process_type_annotation(self, s: AssignmentStmt) -> None:
30073007
):
30083008
self.fail("All protocol members must have explicitly declared types", s)
30093009
# Set the type if the rvalue is a simple literal (even if the above error occurred).
3010-
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr):
3010+
# We skip this step for type scope because it messes up with class attribute
3011+
# inference for literal types (also annotated and non-annotated variables at class
3012+
# scope are semantically different, so we should not souch statement type).
3013+
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr) and not self.type:
30113014
if s.lvalues[0].is_inferred_def:
30123015
s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def)
30133016
if s.type:
@@ -3026,7 +3029,6 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool:
30263029

30273030
def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Type | None:
30283031
"""Return builtins.int if rvalue is an int literal, etc.
3029-
30303032
If this is a 'Final' context, we return "Literal[...]" instead."""
30313033
if self.options.semantic_analysis_only or self.function_stack:
30323034
# Skip this if we're only doing the semantic analysis pass.

test-data/unit/check-classes.test

+3-3
Original file line numberDiff line numberDiff line change
@@ -4317,7 +4317,7 @@ class C(B):
43174317
x = object()
43184318
[out]
43194319
main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
4320-
main:6: error: Incompatible types in assignment (expression has type "object", base class "B" defined the type as "str")
4320+
main:6: error: Incompatible types in assignment (expression has type "object", base class "A" defined the type as "int")
43214321

43224322
[case testClassOneErrorPerLine]
43234323
class A:
@@ -4327,15 +4327,15 @@ class B(A):
43274327
x = 1.0
43284328
[out]
43294329
main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
4330-
main:5: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
4330+
main:5: error: Incompatible types in assignment (expression has type "float", base class "A" defined the type as "int")
43314331

43324332
[case testClassIgnoreType_RedefinedAttributeAndGrandparentAttributeTypesNotIgnored]
43334333
class A:
43344334
x = 0
43354335
class B(A):
43364336
x = '' # type: ignore
43374337
class C(B):
4338-
x = ''
4338+
x = '' # E: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
43394339
[out]
43404340

43414341
[case testClassIgnoreType_RedefinedAttributeTypeIgnoredInChildren]

test-data/unit/check-inference.test

+65
Original file line numberDiff line numberDiff line change
@@ -3263,3 +3263,68 @@ from typing import Dict, Iterable, Tuple, Union
32633263
def foo(x: Union[Tuple[str, Dict[str, int], str], Iterable[object]]) -> None: ...
32643264
foo(("a", {"a": "b"}, "b"))
32653265
[builtins fixtures/dict.pyi]
3266+
3267+
[case testUseSupertypeAsInferenceContext]
3268+
# flags: --strict-optional
3269+
from typing import List, Optional
3270+
3271+
class B:
3272+
x: List[Optional[int]]
3273+
3274+
class C(B):
3275+
x = [1]
3276+
3277+
reveal_type(C().x) # N: Revealed type is "builtins.list[Union[builtins.int, None]]"
3278+
[builtins fixtures/list.pyi]
3279+
3280+
[case testUseSupertypeAsInferenceContextInvalidType]
3281+
from typing import List
3282+
class P:
3283+
x: List[int]
3284+
class C(P):
3285+
x = ['a'] # E: List item 0 has incompatible type "str"; expected "int"
3286+
[builtins fixtures/list.pyi]
3287+
3288+
[case testUseSupertypeAsInferenceContextPartial]
3289+
from typing import List
3290+
3291+
class A:
3292+
x: List[str]
3293+
3294+
class B(A):
3295+
x = []
3296+
3297+
reveal_type(B().x) # N: Revealed type is "builtins.list[builtins.str]"
3298+
[builtins fixtures/list.pyi]
3299+
3300+
[case testUseSupertypeAsInferenceContextPartialError]
3301+
class A:
3302+
x = ['a', 'b']
3303+
3304+
class B(A):
3305+
x = []
3306+
x.append(2) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str"
3307+
[builtins fixtures/list.pyi]
3308+
3309+
[case testUseSupertypeAsInferenceContextPartialErrorProperty]
3310+
from typing import List
3311+
3312+
class P:
3313+
@property
3314+
def x(self) -> List[int]: ...
3315+
class C(P):
3316+
x = []
3317+
3318+
C.x.append("no") # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
3319+
[builtins fixtures/list.pyi]
3320+
3321+
[case testUseSupertypeAsInferenceContextConflict]
3322+
from typing import List
3323+
class P:
3324+
x: List[int]
3325+
class M:
3326+
x: List[str]
3327+
class C(P, M):
3328+
x = [] # E: Need type annotation for "x" (hint: "x: List[<type>] = ...")
3329+
reveal_type(C.x) # N: Revealed type is "builtins.list[Any]"
3330+
[builtins fixtures/list.pyi]

test-data/unit/check-literal.test

+41
Original file line numberDiff line numberDiff line change
@@ -2918,3 +2918,44 @@ def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False]
29182918
else:
29192919
return (bool(), 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]")
29202920
[builtins fixtures/bool.pyi]
2921+
2922+
[case testLiteralSubtypeContext]
2923+
from typing_extensions import Literal
2924+
2925+
class A:
2926+
foo: Literal['bar', 'spam']
2927+
class B(A):
2928+
foo = 'spam'
2929+
2930+
reveal_type(B().foo) # N: Revealed type is "Literal['spam']"
2931+
[builtins fixtures/tuple.pyi]
2932+
2933+
[case testLiteralSubtypeContextNested]
2934+
from typing import List
2935+
from typing_extensions import Literal
2936+
2937+
class A:
2938+
foo: List[Literal['bar', 'spam']]
2939+
class B(A):
2940+
foo = ['spam']
2941+
2942+
reveal_type(B().foo) # N: Revealed type is "builtins.list[Union[Literal['bar'], Literal['spam']]]"
2943+
[builtins fixtures/tuple.pyi]
2944+
2945+
[case testLiteralSubtypeContextGeneric]
2946+
from typing_extensions import Literal
2947+
from typing import Generic, List, TypeVar
2948+
2949+
T = TypeVar("T", bound=str)
2950+
2951+
class B(Generic[T]):
2952+
collection: List[T]
2953+
word: T
2954+
2955+
class C(B[Literal["word"]]):
2956+
collection = ["word"]
2957+
word = "word"
2958+
2959+
reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word']]"
2960+
reveal_type(C().word) # N: Revealed type is "Literal['word']"
2961+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)