Skip to content

Commit 3b97e6e

Browse files
authored
[PEP 695] Implement new scoping rules for type parameters (#17258)
Type parameters get a separate scope with some special features. Work on #15238.
1 parent 5fb8d62 commit 3b97e6e

File tree

4 files changed

+323
-45
lines changed

4 files changed

+323
-45
lines changed

mypy/nodes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2502,7 +2502,7 @@ class TypeVarLikeExpr(SymbolNode, Expression):
25022502
Note that they are constructed by the semantic analyzer.
25032503
"""
25042504

2505-
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
2505+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance", "is_new_style")
25062506

25072507
_name: str
25082508
_fullname: str
@@ -2525,13 +2525,15 @@ def __init__(
25252525
upper_bound: mypy.types.Type,
25262526
default: mypy.types.Type,
25272527
variance: int = INVARIANT,
2528+
is_new_style: bool = False,
25282529
) -> None:
25292530
super().__init__()
25302531
self._name = name
25312532
self._fullname = fullname
25322533
self.upper_bound = upper_bound
25332534
self.default = default
25342535
self.variance = variance
2536+
self.is_new_style = is_new_style
25352537

25362538
@property
25372539
def name(self) -> str:
@@ -2570,8 +2572,9 @@ def __init__(
25702572
upper_bound: mypy.types.Type,
25712573
default: mypy.types.Type,
25722574
variance: int = INVARIANT,
2575+
is_new_style: bool = False,
25732576
) -> None:
2574-
super().__init__(name, fullname, upper_bound, default, variance)
2577+
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
25752578
self.values = values
25762579

25772580
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2648,8 +2651,9 @@ def __init__(
26482651
tuple_fallback: mypy.types.Instance,
26492652
default: mypy.types.Type,
26502653
variance: int = INVARIANT,
2654+
is_new_style: bool = False,
26512655
) -> None:
2652-
super().__init__(name, fullname, upper_bound, default, variance)
2656+
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
26532657
self.tuple_fallback = tuple_fallback
26542658

26552659
def accept(self, visitor: ExpressionVisitor[T]) -> T:

mypy/semanal.py

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,14 @@
317317
CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"]
318318

319319

320+
# Python has several different scope/namespace kinds with subtly different semantics.
321+
SCOPE_GLOBAL: Final = 0 # Module top level
322+
SCOPE_CLASS: Final = 1 # Class body
323+
SCOPE_FUNC: Final = 2 # Function or lambda
324+
SCOPE_COMPREHENSION: Final = 3 # Comprehension or generator expression
325+
SCOPE_ANNOTATION: Final = 4 # Annotation scopes for type parameters and aliases (PEP 695)
326+
327+
320328
# Used for tracking incomplete references
321329
Tag: _TypeAlias = int
322330

@@ -342,8 +350,8 @@ class SemanticAnalyzer(
342350
nonlocal_decls: list[set[str]]
343351
# Local names of function scopes; None for non-function scopes.
344352
locals: list[SymbolTable | None]
345-
# Whether each scope is a comprehension scope.
346-
is_comprehension_stack: list[bool]
353+
# Type of each scope (SCOPE_*, indexes match locals)
354+
scope_stack: list[int]
347355
# Nested block depths of scopes
348356
block_depth: list[int]
349357
# TypeInfo of directly enclosing class (or None)
@@ -417,7 +425,7 @@ def __init__(
417425
errors: Report analysis errors using this instance
418426
"""
419427
self.locals = [None]
420-
self.is_comprehension_stack = [False]
428+
self.scope_stack = [SCOPE_GLOBAL]
421429
# Saved namespaces from previous iteration. Every top-level function/method body is
422430
# analyzed in several iterations until all names are resolved. We need to save
423431
# the local namespaces for the top level function and all nested functions between
@@ -880,6 +888,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
880888
# Don't store not ready types (including placeholders).
881889
if self.found_incomplete_ref(tag) or has_placeholder(result):
882890
self.defer(defn)
891+
# TODO: pop type args
883892
return
884893
assert isinstance(result, ProperType)
885894
if isinstance(result, CallableType):
@@ -1645,6 +1654,8 @@ def push_type_args(
16451654
) -> list[tuple[str, TypeVarLikeExpr]] | None:
16461655
if not type_args:
16471656
return []
1657+
self.locals.append(SymbolTable())
1658+
self.scope_stack.append(SCOPE_ANNOTATION)
16481659
tvs: list[tuple[str, TypeVarLikeExpr]] = []
16491660
for p in type_args:
16501661
tv = self.analyze_type_param(p)
@@ -1653,10 +1664,23 @@ def push_type_args(
16531664
tvs.append((p.name, tv))
16541665

16551666
for name, tv in tvs:
1656-
self.add_symbol(name, tv, context, no_progress=True)
1667+
if self.is_defined_type_param(name):
1668+
self.fail(f'"{name}" already defined as a type parameter', context)
1669+
else:
1670+
self.add_symbol(name, tv, context, no_progress=True, type_param=True)
16571671

16581672
return tvs
16591673

1674+
def is_defined_type_param(self, name: str) -> bool:
1675+
for names in self.locals:
1676+
if names is None:
1677+
continue
1678+
if name in names:
1679+
node = names[name].node
1680+
if isinstance(node, TypeVarLikeExpr):
1681+
return True
1682+
return False
1683+
16601684
def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
16611685
fullname = self.qualified_name(type_param.name)
16621686
if type_param.upper_bound:
@@ -1681,10 +1705,15 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
16811705
upper_bound=upper_bound,
16821706
default=default,
16831707
variance=VARIANCE_NOT_READY,
1708+
is_new_style=True,
16841709
)
16851710
elif type_param.kind == PARAM_SPEC_KIND:
16861711
return ParamSpecExpr(
1687-
name=type_param.name, fullname=fullname, upper_bound=upper_bound, default=default
1712+
name=type_param.name,
1713+
fullname=fullname,
1714+
upper_bound=upper_bound,
1715+
default=default,
1716+
is_new_style=True,
16881717
)
16891718
else:
16901719
assert type_param.kind == TYPE_VAR_TUPLE_KIND
@@ -1696,14 +1725,14 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
16961725
upper_bound=tuple_fallback.copy_modified(),
16971726
tuple_fallback=tuple_fallback,
16981727
default=default,
1728+
is_new_style=True,
16991729
)
17001730

17011731
def pop_type_args(self, type_args: list[TypeParam] | None) -> None:
17021732
if not type_args:
17031733
return
1704-
for tv in type_args:
1705-
names = self.current_symbol_table()
1706-
del names[tv.name]
1734+
self.locals.pop()
1735+
self.scope_stack.pop()
17071736

17081737
def analyze_class(self, defn: ClassDef) -> None:
17091738
fullname = self.qualified_name(defn.name)
@@ -1785,8 +1814,18 @@ def analyze_class(self, defn: ClassDef) -> None:
17851814
defn.info.is_protocol = is_protocol
17861815
self.recalculate_metaclass(defn, declared_metaclass)
17871816
defn.info.runtime_protocol = False
1817+
1818+
if defn.type_args:
1819+
# PEP 695 type parameters are not in scope in class decorators, so
1820+
# temporarily disable type parameter namespace.
1821+
type_params_names = self.locals.pop()
1822+
self.scope_stack.pop()
17881823
for decorator in defn.decorators:
17891824
self.analyze_class_decorator(defn, decorator)
1825+
if defn.type_args:
1826+
self.locals.append(type_params_names)
1827+
self.scope_stack.append(SCOPE_ANNOTATION)
1828+
17901829
self.analyze_class_body_common(defn)
17911830

17921831
def setup_type_vars(self, defn: ClassDef, tvar_defs: list[TypeVarLikeType]) -> None:
@@ -1938,7 +1977,7 @@ def enter_class(self, info: TypeInfo) -> None:
19381977
# Remember previous active class
19391978
self.type_stack.append(self.type)
19401979
self.locals.append(None) # Add class scope
1941-
self.is_comprehension_stack.append(False)
1980+
self.scope_stack.append(SCOPE_CLASS)
19421981
self.block_depth.append(-1) # The class body increments this to 0
19431982
self.loop_depth.append(0)
19441983
self._type = info
@@ -1949,7 +1988,7 @@ def leave_class(self) -> None:
19491988
self.block_depth.pop()
19501989
self.loop_depth.pop()
19511990
self.locals.pop()
1952-
self.is_comprehension_stack.pop()
1991+
self.scope_stack.pop()
19531992
self._type = self.type_stack.pop()
19541993
self.missing_names.pop()
19551994

@@ -2923,8 +2962,8 @@ class C:
29232962
[(j := i) for i in [1, 2, 3]]
29242963
is a syntax error that is not enforced by Python parser, but at later steps.
29252964
"""
2926-
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
2927-
if not is_comprehension and i < len(self.locals) - 1:
2965+
for i, scope_type in enumerate(reversed(self.scope_stack)):
2966+
if scope_type != SCOPE_COMPREHENSION and i < len(self.locals) - 1:
29282967
if self.locals[-1 - i] is None:
29292968
self.fail(
29302969
"Assignment expression within a comprehension"
@@ -5188,8 +5227,14 @@ def visit_nonlocal_decl(self, d: NonlocalDecl) -> None:
51885227
self.fail("nonlocal declaration not allowed at module level", d)
51895228
else:
51905229
for name in d.names:
5191-
for table in reversed(self.locals[:-1]):
5230+
for table, scope_type in zip(
5231+
reversed(self.locals[:-1]), reversed(self.scope_stack[:-1])
5232+
):
51925233
if table is not None and name in table:
5234+
if scope_type == SCOPE_ANNOTATION:
5235+
self.fail(
5236+
f'nonlocal binding not allowed for type parameter "{name}"', d
5237+
)
51935238
break
51945239
else:
51955240
self.fail(f'No binding for nonlocal "{name}" found', d)
@@ -5350,7 +5395,7 @@ def visit_star_expr(self, expr: StarExpr) -> None:
53505395
def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
53515396
if not self.is_func_scope():
53525397
self.fail('"yield from" outside function', e, serious=True, blocker=True)
5353-
elif self.is_comprehension_stack[-1]:
5398+
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
53545399
self.fail(
53555400
'"yield from" inside comprehension or generator expression',
53565401
e,
@@ -5848,7 +5893,7 @@ def visit__promote_expr(self, expr: PromoteExpr) -> None:
58485893
def visit_yield_expr(self, e: YieldExpr) -> None:
58495894
if not self.is_func_scope():
58505895
self.fail('"yield" outside function', e, serious=True, blocker=True)
5851-
elif self.is_comprehension_stack[-1]:
5896+
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
58525897
self.fail(
58535898
'"yield" inside comprehension or generator expression',
58545899
e,
@@ -6281,6 +6326,7 @@ def add_symbol(
62816326
can_defer: bool = True,
62826327
escape_comprehensions: bool = False,
62836328
no_progress: bool = False,
6329+
type_param: bool = False,
62846330
) -> bool:
62856331
"""Add symbol to the currently active symbol table.
62866332
@@ -6303,7 +6349,7 @@ def add_symbol(
63036349
kind, node, module_public=module_public, module_hidden=module_hidden
63046350
)
63056351
return self.add_symbol_table_node(
6306-
name, symbol, context, can_defer, escape_comprehensions, no_progress
6352+
name, symbol, context, can_defer, escape_comprehensions, no_progress, type_param
63076353
)
63086354

63096355
def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None:
@@ -6336,6 +6382,7 @@ def add_symbol_table_node(
63366382
can_defer: bool = True,
63376383
escape_comprehensions: bool = False,
63386384
no_progress: bool = False,
6385+
type_param: bool = False,
63396386
) -> bool:
63406387
"""Add symbol table node to the currently active symbol table.
63416388
@@ -6355,7 +6402,9 @@ def add_symbol_table_node(
63556402
can_defer: if True, defer current target if adding a placeholder
63566403
context: error context (see above about None value)
63576404
"""
6358-
names = self.current_symbol_table(escape_comprehensions=escape_comprehensions)
6405+
names = self.current_symbol_table(
6406+
escape_comprehensions=escape_comprehensions, type_param=type_param
6407+
)
63596408
existing = names.get(name)
63606409
if isinstance(symbol.node, PlaceholderNode) and can_defer:
63616410
if context is not None:
@@ -6673,7 +6722,7 @@ def enter(
66736722
names = self.saved_locals.setdefault(function, SymbolTable())
66746723
self.locals.append(names)
66756724
is_comprehension = isinstance(function, (GeneratorExpr, DictionaryComprehension))
6676-
self.is_comprehension_stack.append(is_comprehension)
6725+
self.scope_stack.append(SCOPE_FUNC if not is_comprehension else SCOPE_COMPREHENSION)
66776726
self.global_decls.append(set())
66786727
self.nonlocal_decls.append(set())
66796728
# -1 since entering block will increment this to 0.
@@ -6684,19 +6733,22 @@ def enter(
66846733
yield
66856734
finally:
66866735
self.locals.pop()
6687-
self.is_comprehension_stack.pop()
6736+
self.scope_stack.pop()
66886737
self.global_decls.pop()
66896738
self.nonlocal_decls.pop()
66906739
self.block_depth.pop()
66916740
self.loop_depth.pop()
66926741
self.missing_names.pop()
66936742

66946743
def is_func_scope(self) -> bool:
6695-
return self.locals[-1] is not None
6744+
scope_type = self.scope_stack[-1]
6745+
if scope_type == SCOPE_ANNOTATION:
6746+
scope_type = self.scope_stack[-2]
6747+
return scope_type in (SCOPE_FUNC, SCOPE_COMPREHENSION)
66966748

66976749
def is_nested_within_func_scope(self) -> bool:
66986750
"""Are we underneath a function scope, even if we are in a nested class also?"""
6699-
return any(l is not None for l in self.locals)
6751+
return any(s in (SCOPE_FUNC, SCOPE_COMPREHENSION) for s in self.scope_stack)
67006752

67016753
def is_class_scope(self) -> bool:
67026754
return self.type is not None and not self.is_func_scope()
@@ -6713,14 +6765,24 @@ def current_symbol_kind(self) -> int:
67136765
kind = GDEF
67146766
return kind
67156767

6716-
def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTable:
6717-
if self.is_func_scope():
6718-
assert self.locals[-1] is not None
6768+
def current_symbol_table(
6769+
self, escape_comprehensions: bool = False, type_param: bool = False
6770+
) -> SymbolTable:
6771+
if type_param and self.scope_stack[-1] == SCOPE_ANNOTATION:
6772+
n = self.locals[-1]
6773+
assert n is not None
6774+
return n
6775+
elif self.is_func_scope():
6776+
if self.scope_stack[-1] == SCOPE_ANNOTATION:
6777+
n = self.locals[-2]
6778+
else:
6779+
n = self.locals[-1]
6780+
assert n is not None
67196781
if escape_comprehensions:
6720-
assert len(self.locals) == len(self.is_comprehension_stack)
6782+
assert len(self.locals) == len(self.scope_stack)
67216783
# Retrieve the symbol table from the enclosing non-comprehension scope.
6722-
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
6723-
if not is_comprehension:
6784+
for i, scope_type in enumerate(reversed(self.scope_stack)):
6785+
if scope_type != SCOPE_COMPREHENSION:
67246786
if i == len(self.locals) - 1: # The last iteration.
67256787
# The caller of the comprehension is in the global space.
67266788
names = self.globals
@@ -6734,7 +6796,7 @@ def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTab
67346796
else:
67356797
assert False, "Should have at least one non-comprehension scope"
67366798
else:
6737-
names = self.locals[-1]
6799+
names = n
67386800
assert names is not None
67396801
elif self.type is not None:
67406802
names = self.type.names

0 commit comments

Comments
 (0)