Skip to content

Commit a568f3a

Browse files
authored
Add foundation for TypeVar defaults (PEP 696) (#14872)
Start implementing [PEP 696](https://peps.python.org/pep-0696/) TypeVar defaults. This PR * Adds a `default` parameter to `TypeVarLikeExpr` and `TypeVarLikeType`. * Updates most visitors to account for the new `default` parameter. * Update existing calls to add value for `default` => `AnyType(TypeOfAny.from_omitted_generics)`. A followup PR will update the semantic analyzer and add basic tests for `TypeVar`, `ParamSpec`, and `TypeVarTuple` calls with a `default` argument. -> #14873 Ref #14851
1 parent 7fe1fdd commit a568f3a

23 files changed

+356
-84
lines changed

mypy/checker.py

+1
Original file line numberDiff line numberDiff line change
@@ -7191,6 +7191,7 @@ def detach_callable(typ: CallableType) -> CallableType:
71917191
id=var.id,
71927192
values=var.values,
71937193
upper_bound=var.upper_bound,
7194+
default=var.default,
71947195
variance=var.variance,
71957196
)
71967197
)

mypy/checkexpr.py

+60-6
Original file line numberDiff line numberDiff line change
@@ -4189,7 +4189,14 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
41894189
# Used for list and set expressions, as well as for tuples
41904190
# containing star expressions that don't refer to a
41914191
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
4192-
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
4192+
tv = TypeVarType(
4193+
"T",
4194+
"T",
4195+
id=-1,
4196+
values=[],
4197+
upper_bound=self.object_type(),
4198+
default=AnyType(TypeOfAny.from_omitted_generics),
4199+
)
41934200
constructor = CallableType(
41944201
[tv],
41954202
[nodes.ARG_STAR],
@@ -4357,8 +4364,22 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
43574364
return dt
43584365

43594366
# Define type variables (used in constructors below).
4360-
kt = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
4361-
vt = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
4367+
kt = TypeVarType(
4368+
"KT",
4369+
"KT",
4370+
id=-1,
4371+
values=[],
4372+
upper_bound=self.object_type(),
4373+
default=AnyType(TypeOfAny.from_omitted_generics),
4374+
)
4375+
vt = TypeVarType(
4376+
"VT",
4377+
"VT",
4378+
id=-2,
4379+
values=[],
4380+
upper_bound=self.object_type(),
4381+
default=AnyType(TypeOfAny.from_omitted_generics),
4382+
)
43624383

43634384
# Collect function arguments, watching out for **expr.
43644385
args: list[Expression] = []
@@ -4722,7 +4743,14 @@ def check_generator_or_comprehension(
47224743

47234744
# Infer the type of the list comprehension by using a synthetic generic
47244745
# callable type.
4725-
tv = TypeVarType("T", "T", id=-1, values=[], upper_bound=self.object_type())
4746+
tv = TypeVarType(
4747+
"T",
4748+
"T",
4749+
id=-1,
4750+
values=[],
4751+
upper_bound=self.object_type(),
4752+
default=AnyType(TypeOfAny.from_omitted_generics),
4753+
)
47264754
tv_list: list[Type] = [tv]
47274755
constructor = CallableType(
47284756
tv_list,
@@ -4742,8 +4770,22 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
47424770

47434771
# Infer the type of the list comprehension by using a synthetic generic
47444772
# callable type.
4745-
ktdef = TypeVarType("KT", "KT", id=-1, values=[], upper_bound=self.object_type())
4746-
vtdef = TypeVarType("VT", "VT", id=-2, values=[], upper_bound=self.object_type())
4773+
ktdef = TypeVarType(
4774+
"KT",
4775+
"KT",
4776+
id=-1,
4777+
values=[],
4778+
upper_bound=self.object_type(),
4779+
default=AnyType(TypeOfAny.from_omitted_generics),
4780+
)
4781+
vtdef = TypeVarType(
4782+
"VT",
4783+
"VT",
4784+
id=-2,
4785+
values=[],
4786+
upper_bound=self.object_type(),
4787+
default=AnyType(TypeOfAny.from_omitted_generics),
4788+
)
47474789
constructor = CallableType(
47484790
[ktdef, vtdef],
47494791
[nodes.ARG_POS, nodes.ARG_POS],
@@ -5264,6 +5306,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
52645306
return False
52655307
return super().visit_callable_type(t)
52665308

5309+
def visit_type_var(self, t: TypeVarType) -> bool:
5310+
default = [t.default] if t.has_default() else []
5311+
return self.query_types([t.upper_bound, *default] + t.values)
5312+
5313+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5314+
default = [t.default] if t.has_default() else []
5315+
return self.query_types([t.upper_bound, *default])
5316+
5317+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5318+
default = [t.default] if t.has_default() else []
5319+
return self.query_types([t.upper_bound, *default])
5320+
52675321

52685322
def has_coroutine_decorator(t: Type) -> bool:
52695323
"""Whether t came from a function decorated with `@coroutine`."""

mypy/copytype.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
7272
return self.copy_common(t, t.copy_modified())
7373

7474
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
75-
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
75+
dup = ParamSpecType(
76+
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
77+
)
7678
return self.copy_common(t, dup)
7779

7880
def visit_parameters(self, t: Parameters) -> ProperType:
@@ -86,7 +88,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
8688
return self.copy_common(t, dup)
8789

8890
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
89-
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
91+
dup = TypeVarTupleType(
92+
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
93+
)
9094
return self.copy_common(t, dup)
9195

9296
def visit_unpack_type(self, t: UnpackType) -> ProperType:

mypy/expandtype.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,7 @@ def freshen_function_type_vars(callee: F) -> F:
130130
tvs = []
131131
tvmap: dict[TypeVarId, Type] = {}
132132
for v in callee.variables:
133-
if isinstance(v, TypeVarType):
134-
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
135-
elif isinstance(v, TypeVarTupleType):
136-
assert isinstance(v, TypeVarTupleType)
137-
tv = TypeVarTupleType.new_unification_variable(v)
138-
else:
139-
assert isinstance(v, ParamSpecType)
140-
tv = ParamSpecType.new_unification_variable(v)
133+
tv = v.new_unification_variable(v)
141134
tvs.append(tv)
142135
tvmap[v.id] = tv
143136
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)

mypy/fixup.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
171171
for value in v.values:
172172
value.accept(self.type_fixer)
173173
v.upper_bound.accept(self.type_fixer)
174+
v.default.accept(self.type_fixer)
174175

175176
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176177
for value in tv.values:
177178
value.accept(self.type_fixer)
178179
tv.upper_bound.accept(self.type_fixer)
180+
tv.default.accept(self.type_fixer)
179181

180182
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
181183
p.upper_bound.accept(self.type_fixer)
184+
p.default.accept(self.type_fixer)
182185

183186
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
184187
tv.upper_bound.accept(self.type_fixer)
188+
tv.default.accept(self.type_fixer)
185189

186190
def visit_var(self, v: Var) -> None:
187191
if self.current_info is not None:
@@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
303307
if tvt.values:
304308
for vt in tvt.values:
305309
vt.accept(self)
306-
if tvt.upper_bound is not None:
307-
tvt.upper_bound.accept(self)
310+
tvt.upper_bound.accept(self)
311+
tvt.default.accept(self)
308312

309313
def visit_param_spec(self, p: ParamSpecType) -> None:
310314
p.upper_bound.accept(self)
315+
p.default.accept(self)
311316

312317
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
313318
t.upper_bound.accept(self)
319+
t.default.accept(self)
314320

315321
def visit_unpack_type(self, u: UnpackType) -> None:
316322
u.type.accept(self)

mypy/indirection.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
6464
return set()
6565

6666
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound)
67+
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
6868

6969
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return set()
70+
return self._visit(t.upper_bound) | self._visit(t.default)
7171

7272
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound)
73+
return self._visit(t.upper_bound) | self._visit(t.default)
7474

7575
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
7676
return t.type.accept(self)

mypy/nodes.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -2439,26 +2439,35 @@ class TypeVarLikeExpr(SymbolNode, Expression):
24392439
Note that they are constructed by the semantic analyzer.
24402440
"""
24412441

2442-
__slots__ = ("_name", "_fullname", "upper_bound", "variance")
2442+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
24432443

24442444
_name: str
24452445
_fullname: str
24462446
# Upper bound: only subtypes of upper_bound are valid as values. By default
24472447
# this is 'object', meaning no restriction.
24482448
upper_bound: mypy.types.Type
2449+
# Default: used to resolve the TypeVar if the default is not explicitly given.
2450+
# By default this is 'AnyType(TypeOfAny.from_omitted_generics)'. See PEP 696.
2451+
default: mypy.types.Type
24492452
# Variance of the type variable. Invariant is the default.
24502453
# TypeVar(..., covariant=True) defines a covariant type variable.
24512454
# TypeVar(..., contravariant=True) defines a contravariant type
24522455
# variable.
24532456
variance: int
24542457

24552458
def __init__(
2456-
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
2459+
self,
2460+
name: str,
2461+
fullname: str,
2462+
upper_bound: mypy.types.Type,
2463+
default: mypy.types.Type,
2464+
variance: int = INVARIANT,
24572465
) -> None:
24582466
super().__init__()
24592467
self._name = name
24602468
self._fullname = fullname
24612469
self.upper_bound = upper_bound
2470+
self.default = default
24622471
self.variance = variance
24632472

24642473
@property
@@ -2496,9 +2505,10 @@ def __init__(
24962505
fullname: str,
24972506
values: list[mypy.types.Type],
24982507
upper_bound: mypy.types.Type,
2508+
default: mypy.types.Type,
24992509
variance: int = INVARIANT,
25002510
) -> None:
2501-
super().__init__(name, fullname, upper_bound, variance)
2511+
super().__init__(name, fullname, upper_bound, default, variance)
25022512
self.values = values
25032513

25042514
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2511,6 +2521,7 @@ def serialize(self) -> JsonDict:
25112521
"fullname": self._fullname,
25122522
"values": [t.serialize() for t in self.values],
25132523
"upper_bound": self.upper_bound.serialize(),
2524+
"default": self.default.serialize(),
25142525
"variance": self.variance,
25152526
}
25162527

@@ -2522,6 +2533,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
25222533
data["fullname"],
25232534
[mypy.types.deserialize_type(v) for v in data["values"]],
25242535
mypy.types.deserialize_type(data["upper_bound"]),
2536+
mypy.types.deserialize_type(data["default"]),
25252537
data["variance"],
25262538
)
25272539

@@ -2540,6 +2552,7 @@ def serialize(self) -> JsonDict:
25402552
"name": self._name,
25412553
"fullname": self._fullname,
25422554
"upper_bound": self.upper_bound.serialize(),
2555+
"default": self.default.serialize(),
25432556
"variance": self.variance,
25442557
}
25452558

@@ -2550,6 +2563,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
25502563
data["name"],
25512564
data["fullname"],
25522565
mypy.types.deserialize_type(data["upper_bound"]),
2566+
mypy.types.deserialize_type(data["default"]),
25532567
data["variance"],
25542568
)
25552569

@@ -2569,9 +2583,10 @@ def __init__(
25692583
fullname: str,
25702584
upper_bound: mypy.types.Type,
25712585
tuple_fallback: mypy.types.Instance,
2586+
default: mypy.types.Type,
25722587
variance: int = INVARIANT,
25732588
) -> None:
2574-
super().__init__(name, fullname, upper_bound, variance)
2589+
super().__init__(name, fullname, upper_bound, default, variance)
25752590
self.tuple_fallback = tuple_fallback
25762591

25772592
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2584,6 +2599,7 @@ def serialize(self) -> JsonDict:
25842599
"fullname": self._fullname,
25852600
"upper_bound": self.upper_bound.serialize(),
25862601
"tuple_fallback": self.tuple_fallback.serialize(),
2602+
"default": self.default.serialize(),
25872603
"variance": self.variance,
25882604
}
25892605

@@ -2595,6 +2611,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
25952611
data["fullname"],
25962612
mypy.types.deserialize_type(data["upper_bound"]),
25972613
mypy.types.Instance.deserialize(data["tuple_fallback"]),
2614+
mypy.types.deserialize_type(data["default"]),
25982615
data["variance"],
25992616
)
26002617

mypy/plugins/attrs.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -772,9 +772,14 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
772772
id=-1,
773773
values=[],
774774
upper_bound=object_type,
775+
default=AnyType(TypeOfAny.from_omitted_generics),
775776
)
776777
self_tvar_expr = TypeVarExpr(
777-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
778+
SELF_TVAR_NAME,
779+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
780+
[],
781+
object_type,
782+
AnyType(TypeOfAny.from_omitted_generics),
778783
)
779784
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
780785

mypy/plugins/dataclasses.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ def transform(self) -> bool:
254254
# Type variable for self types in generated methods.
255255
obj_type = self._api.named_type("builtins.object")
256256
self_tvar_expr = TypeVarExpr(
257-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
257+
SELF_TVAR_NAME,
258+
info.fullname + "." + SELF_TVAR_NAME,
259+
[],
260+
obj_type,
261+
AnyType(TypeOfAny.from_omitted_generics),
258262
)
259263
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
260264

@@ -273,6 +277,7 @@ def transform(self) -> bool:
273277
id=-1,
274278
values=[],
275279
upper_bound=obj_type,
280+
default=AnyType(TypeOfAny.from_omitted_generics),
276281
)
277282
order_return_type = self._api.named_type("builtins.bool")
278283
order_args = [

0 commit comments

Comments
 (0)