Skip to content

Commit 1aba774

Browse files
committed
Fix crash due to checking type variable values too early
Move type variable checks which use subtype and type sameness checks to happen at the end of semantic analysis. The implementation also adds the concept of priorities to semantic analysis patch callbacks. Callback calls are sorted by the priority. We resolve forward references and calculate fallbacks before checking type variable values, as otherwise the latter could see incomplete types and crash. Fixes #4200.
1 parent f7f35e4 commit 1aba774

8 files changed

+129
-77
lines changed

mypy/build.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1853,14 +1853,14 @@ def parse_file(self) -> None:
18531853

18541854
def semantic_analysis(self) -> None:
18551855
assert self.tree is not None, "Internal error: method must be called on parsed file only"
1856-
patches = [] # type: List[Callable[[], None]]
1856+
patches = [] # type: List[Tuple[int, Callable[[], None]]]
18571857
with self.wrap_context():
18581858
self.manager.semantic_analyzer.visit_file(self.tree, self.xpath, self.options, patches)
18591859
self.patches = patches
18601860

18611861
def semantic_analysis_pass_three(self) -> None:
18621862
assert self.tree is not None, "Internal error: method must be called on parsed file only"
1863-
patches = [] # type: List[Callable[[], None]]
1863+
patches = [] # type: List[Tuple[int, Callable[[], None]]]
18641864
with self.wrap_context():
18651865
self.manager.semantic_analyzer_pass3.visit_file(self.tree, self.xpath,
18661866
self.options, patches)
@@ -1869,7 +1869,8 @@ def semantic_analysis_pass_three(self) -> None:
18691869
self.patches = patches + self.patches
18701870

18711871
def semantic_analysis_apply_patches(self) -> None:
1872-
for patch_func in self.patches:
1872+
patches_by_priority = sorted(self.patches, key=lambda x: x[0])
1873+
for priority, patch_func in patches_by_priority:
18731874
patch_func()
18741875

18751876
def type_check_first_pass(self) -> None:

mypy/semanal.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
8585
from mypy import join
8686
from mypy.util import get_prefix
87+
from mypy.semanal_shared import PRIORITY_FALLBACKS
8788

8889

8990
T = TypeVar('T')
@@ -258,11 +259,12 @@ def __init__(self,
258259
self.recurse_into_functions = True
259260

260261
def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
261-
patches: List[Callable[[], None]]) -> None:
262+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
262263
"""Run semantic analysis phase 2 over a file.
263264
264-
Add callbacks by mutating the patches list argument. They will be called
265-
after all semantic analysis phases but before type checking.
265+
Add (priority, callback) pairs by mutating the 'patches' list argument. They
266+
will be called after all semantic analysis phases but before type checking,
267+
lowest priority values first.
266268
"""
267269
self.recurse_into_functions = True
268270
self.options = options
@@ -2454,7 +2456,7 @@ def patch() -> None:
24542456
# We can't calculate the complete fallback type until after semantic
24552457
# analysis, since otherwise MROs might be incomplete. Postpone a callback
24562458
# function that patches the fallback.
2457-
self.patches.append(patch)
2459+
self.patches.append((PRIORITY_FALLBACKS, patch))
24582460

24592461
def add_field(var: Var, is_initialized_in_class: bool = False,
24602462
is_property: bool = False) -> None:
@@ -2693,7 +2695,7 @@ def patch() -> None:
26932695
# We can't calculate the complete fallback type until after semantic
26942696
# analysis, since otherwise MROs might be incomplete. Postpone a callback
26952697
# function that patches the fallback.
2696-
self.patches.append(patch)
2698+
self.patches.append((PRIORITY_FALLBACKS, patch))
26972699
return info
26982700

26992701
def check_classvar(self, s: AssignmentStmt) -> None:

mypy/semanal_pass3.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from collections import OrderedDict
13-
from typing import Dict, List, Callable, Optional, Union, Set, cast
13+
from typing import Dict, List, Callable, Optional, Union, Set, cast, Tuple
1414

1515
from mypy import messages, experiments
1616
from mypy.nodes import (
@@ -26,8 +26,9 @@
2626
from mypy.errors import Errors, report_internal_error
2727
from mypy.options import Options
2828
from mypy.traverser import TraverserVisitor
29-
from mypy.typeanal import TypeAnalyserPass3, collect_any_types
29+
from mypy.typeanal import TypeAnalyserPass3, collect_any_types, TypeVariableChecker
3030
from mypy.typevars import has_no_typevars
31+
from mypy.semanal_shared import PRIORITY_FORWARD_REF, PRIORITY_TYPEVAR_VALUES
3132
import mypy.semanal
3233

3334

@@ -48,7 +49,7 @@ def __init__(self, modules: Dict[str, MypyFile], errors: Errors,
4849
self.recurse_into_functions = True
4950

5051
def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
51-
patches: List[Callable[[], None]]) -> None:
52+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
5253
self.recurse_into_functions = True
5354
self.errors.set_file(fnam, file_node.fullname())
5455
self.options = options
@@ -349,12 +350,7 @@ def analyze(self, type: Optional[Type], node: Union[Node, SymbolTableNode],
349350
analyzer = self.make_type_analyzer(indicator)
350351
type.accept(analyzer)
351352
self.check_for_omitted_generics(type)
352-
if indicator.get('forward') or indicator.get('synthetic'):
353-
def patch() -> None:
354-
self.perform_transform(node,
355-
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
356-
node, warn)))
357-
self.patches.append(patch)
353+
self.generate_type_patches(node, indicator, warn)
358354

359355
def analyze_types(self, types: List[Type], node: Node) -> None:
360356
# Similar to above but for nodes with multiple types.
@@ -363,12 +359,24 @@ def analyze_types(self, types: List[Type], node: Node) -> None:
363359
analyzer = self.make_type_analyzer(indicator)
364360
type.accept(analyzer)
365361
self.check_for_omitted_generics(type)
362+
self.generate_type_patches(node, indicator, warn=False)
363+
364+
def generate_type_patches(self,
365+
node: Union[Node, SymbolTableNode],
366+
indicator: Dict[str, bool],
367+
warn: bool) -> None:
366368
if indicator.get('forward') or indicator.get('synthetic'):
367369
def patch() -> None:
368370
self.perform_transform(node,
369371
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
370-
node, warn=False)))
371-
self.patches.append(patch)
372+
node, warn)))
373+
self.patches.append((PRIORITY_FORWARD_REF, patch))
374+
if indicator.get('typevar'):
375+
def patch() -> None:
376+
self.perform_transform(node,
377+
lambda tp: tp.accept(TypeVariableChecker(self.fail)))
378+
379+
self.patches.append((PRIORITY_TYPEVAR_VALUES, patch))
372380

373381
def analyze_info(self, info: TypeInfo) -> None:
374382
# Similar to above but for nodes with synthetic TypeInfos (NamedTuple and NewType).
@@ -387,7 +395,8 @@ def make_type_analyzer(self, indicator: Dict[str, bool]) -> TypeAnalyserPass3:
387395
self.sem.plugin,
388396
self.options,
389397
self.is_typeshed_file,
390-
indicator)
398+
indicator,
399+
self.patches)
391400

392401
def check_for_omitted_generics(self, typ: Type) -> None:
393402
if not self.options.disallow_any_generics or self.is_typeshed_file:

mypy/semanal_shared.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Shared definitions used by different parts of semantic analysis."""
2+
3+
# Priorities for ordering of patches within the final "patch" phase of semantic analysis
4+
# (after pass 3):
5+
6+
# Fix forward references (needs to happen first)
7+
PRIORITY_FORWARD_REF = 0
8+
# Fix fallbacks (does joins)
9+
PRIORITY_FALLBACKS = 1
10+
# Checks type var values (does subtype checks)
11+
PRIORITY_TYPEVAR_VALUES = 2

mypy/typeanal.py

+71-56
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Semantic analysis of types"""
22

33
from collections import OrderedDict
4-
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict
4+
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict, Union
55
from itertools import chain
66

77
from contextlib import contextmanager
@@ -14,14 +14,15 @@
1414
Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType,
1515
CallableType, NoneTyp, DeletedType, TypeList, TypeVarDef, TypeVisitor, SyntheticTypeVisitor,
1616
StarType, PartialType, EllipsisType, UninhabitedType, TypeType, get_typ_args, set_typ_args,
17-
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded
17+
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded,
18+
TypeTranslator
1819
)
1920

2021
from mypy.nodes import (
2122
TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, TypeInfo, Context, SymbolTableNode, Var, Expression,
2223
IndexExpr, RefExpr, nongen_builtins, check_arg_names, check_arg_kinds, ARG_POS, ARG_NAMED,
2324
ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, FuncDef, CallExpr, NameExpr,
24-
Decorator
25+
Decorator, Node
2526
)
2627
from mypy.tvar_scope import TypeVarScope
2728
from mypy.sametypes import is_same_type
@@ -656,7 +657,8 @@ def __init__(self,
656657
plugin: Plugin,
657658
options: Options,
658659
is_typeshed_stub: bool,
659-
indicator: Dict[str, bool]) -> None:
660+
indicator: Dict[str, bool],
661+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
660662
self.lookup_func = lookup_func
661663
self.lookup_fqn_func = lookup_fqn_func
662664
self.fail = fail_func
@@ -665,6 +667,7 @@ def __init__(self,
665667
self.plugin = plugin
666668
self.is_typeshed_stub = is_typeshed_stub
667669
self.indicator = indicator
670+
self.patches = patches
668671

669672
def visit_instance(self, t: Instance) -> None:
670673
info = t.type
@@ -707,64 +710,21 @@ def visit_instance(self, t: Instance) -> None:
707710
t.args = [AnyType(TypeOfAny.from_error) for _ in info.type_vars]
708711
t.invalid = True
709712
elif info.defn.type_vars:
710-
# Check type argument values.
711-
# TODO: Calling is_subtype and is_same_types in semantic analysis is a bad idea
712-
for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars):
713-
if tvar.values:
714-
if isinstance(arg, TypeVarType):
715-
arg_values = arg.values
716-
if not arg_values:
717-
self.fail('Type variable "{}" not valid as type '
718-
'argument value for "{}"'.format(
719-
arg.name, info.name()), t)
720-
continue
721-
else:
722-
arg_values = [arg]
723-
self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t)
724-
# TODO: These hacks will be not necessary when this will be moved to later stage.
725-
arg = self.resolve_type(arg)
726-
bound = self.resolve_type(tvar.upper_bound)
727-
if not is_subtype(arg, bound):
728-
self.fail('Type argument "{}" of "{}" must be '
729-
'a subtype of "{}"'.format(
730-
arg, info.name(), bound), t)
713+
# Check type argument values. This is postponed to the end of semantic analysis
714+
# since we need full MROs and resolved forward references.
715+
for tvar in info.defn.type_vars:
716+
if (tvar.values
717+
or not isinstance(tvar.upper_bound, Instance)
718+
or tvar.upper_bound.type.fullname() != 'builtins.object'):
719+
# Some restrictions on type variable. These can only be checked later
720+
# after we have final MROs and forward references have been resolved.
721+
self.indicator['typevar'] = True
731722
for arg in t.args:
732723
arg.accept(self)
733724
if info.is_newtype:
734725
for base in info.bases:
735726
base.accept(self)
736727

737-
def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
738-
valids: List[Type], arg_number: int, context: Context) -> None:
739-
for actual in actuals:
740-
actual = self.resolve_type(actual)
741-
if (not isinstance(actual, AnyType) and
742-
not any(is_same_type(actual, self.resolve_type(value))
743-
for value in valids)):
744-
if len(actuals) > 1 or not isinstance(actual, Instance):
745-
self.fail('Invalid type argument value for "{}"'.format(
746-
type.name()), context)
747-
else:
748-
class_name = '"{}"'.format(type.name())
749-
actual_type_name = '"{}"'.format(actual.type.name())
750-
self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
751-
arg_name, class_name, actual_type_name), context)
752-
753-
def resolve_type(self, tp: Type) -> Type:
754-
# This helper is only needed while is_subtype and is_same_type are
755-
# called in third pass. This can be removed when TODO in visit_instance is fixed.
756-
if isinstance(tp, ForwardRef):
757-
if tp.resolved is None:
758-
return tp.unbound
759-
tp = tp.resolved
760-
if isinstance(tp, Instance) and tp.type.replaced:
761-
replaced = tp.type.replaced
762-
if replaced.tuple_type:
763-
tp = replaced.tuple_type
764-
if replaced.typeddict_type:
765-
tp = replaced.typeddict_type
766-
return tp
767-
768728
def visit_callable_type(self, t: CallableType) -> None:
769729
t.ret_type.accept(self)
770730
for arg_type in t.arg_types:
@@ -1036,3 +996,58 @@ def make_optional_type(t: Type) -> Type:
1036996
return UnionType(items + [NoneTyp()], t.line, t.column)
1037997
else:
1038998
return UnionType([t, NoneTyp()], t.line, t.column)
999+
1000+
1001+
class TypeVariableChecker(TypeTranslator):
1002+
"""Visitor that checks that type variables in generic types have valid values.
1003+
1004+
Note: This must be run at the end of semantic analysis when MROs are
1005+
complete and forward references have been resolved.
1006+
1007+
This does two things:
1008+
1009+
- If type variable in C has a value restriction, check that X in C[X] conforms
1010+
to the restriction.
1011+
- If type variable in C has a non-default upper bound, check that X in C[X]
1012+
conforms to the upper bound.
1013+
1014+
(This doesn't need to be a type translator, but it simplifies the implementation.)
1015+
"""
1016+
1017+
def __init__(self, fail: Callable[[str, Context], None]) -> None:
1018+
self.fail = fail
1019+
1020+
def visit_instance(self, t: Instance) -> Type:
1021+
info = t.type
1022+
for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars):
1023+
if tvar.values:
1024+
if isinstance(arg, TypeVarType):
1025+
arg_values = arg.values
1026+
if not arg_values:
1027+
self.fail('Type variable "{}" not valid as type '
1028+
'argument value for "{}"'.format(
1029+
arg.name, info.name()), t)
1030+
continue
1031+
else:
1032+
arg_values = [arg]
1033+
self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t)
1034+
if not is_subtype(arg, tvar.upper_bound):
1035+
self.fail('Type argument "{}" of "{}" must be '
1036+
'a subtype of "{}"'.format(
1037+
arg, info.name(), tvar.upper_bound), t)
1038+
return t
1039+
1040+
def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
1041+
valids: List[Type], arg_number: int, context: Context) -> None:
1042+
for actual in actuals:
1043+
if (not isinstance(actual, AnyType) and
1044+
not any(is_same_type(actual, value)
1045+
for value in valids)):
1046+
if len(actuals) > 1 or not isinstance(actual, Instance):
1047+
self.fail('Invalid type argument value for "{}"'.format(
1048+
type.name()), context)
1049+
else:
1050+
class_name = '"{}"'.format(type.name())
1051+
actual_type_name = '"{}"'.format(actual.type.name())
1052+
self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
1053+
arg_name, class_name, actual_type_name), context)

test-data/unit/check-newtype.test

+7
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,10 @@ d: object
360360
if isinstance(d, T): # E: Cannot use isinstance() with a NewType type
361361
reveal_type(d) # E: Revealed type is '__main__.T'
362362
[builtins fixtures/isinstancelist.pyi]
363+
364+
[case testInvalidNewTypeCrash]
365+
from typing import List, NewType, Union
366+
N = NewType('N', XXX) # E: Argument 2 to NewType(...) must be subclassable (got "Any") \
367+
# E: Name 'XXX' is not defined
368+
x: List[Union[N, int]] # E: Invalid type "__main__.N"
369+
[builtins fixtures/list.pyi]

test-data/unit/check-typeddict.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,7 @@ T = TypeVar('T', bound='M')
13331333
class G(Generic[T]):
13341334
x: T
13351335

1336-
yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict({'x': builtins.int}, fallback=typing.Mapping[builtins.str, builtins.object])"
1336+
yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict('__main__.M', {'x': builtins.int})"
13371337
yg: G[M]
13381338
z: int = G[M]().x['x']
13391339

test-data/unit/check-unions.test

+7
Original file line numberDiff line numberDiff line change
@@ -940,3 +940,10 @@ x: Union[ExtremelyLongTypeNameWhichIsGenericSoWeCanUseItMultipleTimes[int],
940940
def takes_int(arg: int) -> None: pass
941941

942942
takes_int(x) # E: Argument 1 to "takes_int" has incompatible type <union: 6 items>; expected "int"
943+
944+
[case testRecursiveForwardReferenceInUnion]
945+
from typing import List, Union
946+
MYTYPE = List[Union[str, "MYTYPE"]]
947+
[builtins fixtures/list.pyi]
948+
[out]
949+
main:2: error: Recursive types not fully supported yet, nested types replaced with "Any"

0 commit comments

Comments
 (0)