Skip to content

Commit ce3975d

Browse files
authored
Basic support for ParamSpec type checking (#11594)
Add support for type checking several ParamSpec use cases (PEP 612). @hauntsaninja previously added support for semantic analysis of ParamSpec definitions, and this builds on top that foundation. The implementation has these main things going on: * `ParamSpecType` that is similar to `TypeVarType` but has three "flavors" that correspond to `P`, `P.args` and `P.kwargs` * `CallableType` represents `Callable[P, T]` if the arguments are (`*args: P.args`, `**kwargs: P.kwargs`) -- and more generally, there can also be arbitrary additional prefix arguments * Type variables of functions and classes can now be represented using `ParamSpecType` in addition to `TypeVarType` There are still a bunch of TODOs. Some of these are important to address before the release that includes this. I believe that this is good enough to merge and remaining issues can be fixed in follow-up PRs. Notable missing features include these: * `Concatenate` * Specifying the value of ParamSpec explicitly (e.g. `Z[[int, str, bool]]`) * Various validity checks -- currently only some errors are caught * Special case of decorating a method (python/typeshed#6347) * `atexit.register(lambda: ...)` generates an error
1 parent 9c86f9b commit ce3975d

29 files changed

+829
-197
lines changed

misc/proper_plugin.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def is_special_target(right: ProperType) -> bool:
6666
if right.type_object().fullname in (
6767
'mypy.types.UnboundType',
6868
'mypy.types.TypeVarType',
69+
'mypy.types.ParamSpecType',
6970
'mypy.types.RawExpressionType',
7071
'mypy.types.EllipsisType',
7172
'mypy.types.StarType',

mypy/applytype.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.expandtype import expand_type
66
from mypy.types import (
77
Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types,
8-
TypeVarLikeType, ProperType, ParamSpecType
8+
TypeVarLikeType, ProperType, ParamSpecType, get_proper_type
99
)
1010
from mypy.nodes import Context
1111

@@ -18,9 +18,8 @@ def get_target_type(
1818
context: Context,
1919
skip_unsatisfied: bool
2020
) -> Optional[Type]:
21-
# TODO(PEP612): fix for ParamSpecType
2221
if isinstance(tvar, ParamSpecType):
23-
return None
22+
return type
2423
assert isinstance(tvar, TypeVarType)
2524
values = get_proper_types(tvar.values)
2625
if values:
@@ -90,6 +89,14 @@ def apply_generic_arguments(
9089
if target_type is not None:
9190
id_to_type[tvar.id] = target_type
9291

92+
param_spec = callable.param_spec()
93+
if param_spec is not None:
94+
nt = id_to_type.get(param_spec.id)
95+
if nt is not None:
96+
nt = get_proper_type(nt)
97+
if isinstance(nt, CallableType):
98+
callable = callable.expand_param_spec(nt)
99+
93100
# Apply arguments to argument types.
94101
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
95102

mypy/checker.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
3838
is_named_instance, union_items, TypeQuery, LiteralType,
3939
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
40-
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType)
40+
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType
41+
)
4142
from mypy.sametypes import is_same_type
4243
from mypy.messages import (
4344
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
@@ -976,13 +977,15 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
976977
ctx = typ
977978
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
978979
if typ.arg_kinds[i] == nodes.ARG_STAR:
979-
# builtins.tuple[T] is typing.Tuple[T, ...]
980-
arg_type = self.named_generic_type('builtins.tuple',
981-
[arg_type])
980+
if not isinstance(arg_type, ParamSpecType):
981+
# builtins.tuple[T] is typing.Tuple[T, ...]
982+
arg_type = self.named_generic_type('builtins.tuple',
983+
[arg_type])
982984
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
983-
arg_type = self.named_generic_type('builtins.dict',
984-
[self.str_type(),
985-
arg_type])
985+
if not isinstance(arg_type, ParamSpecType):
986+
arg_type = self.named_generic_type('builtins.dict',
987+
[self.str_type(),
988+
arg_type])
986989
item.arguments[i].variable.type = arg_type
987990

988991
# Type check initialization expressions.
@@ -1883,7 +1886,7 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
18831886
expected = CONTRAVARIANT
18841887
else:
18851888
expected = INVARIANT
1886-
if expected != tvar.variance:
1889+
if isinstance(tvar, TypeVarType) and expected != tvar.variance:
18871890
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)
18881891

18891892
def check_multiple_inheritance(self, typ: TypeInfo) -> None:

mypy/checkexpr.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType,
1919
TupleType, TypedDictType, Instance, ErasedType, UnionType,
2020
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
21-
is_named_instance, FunctionLike, ParamSpecType,
21+
is_named_instance, FunctionLike, ParamSpecType, ParamSpecFlavor,
2222
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType,
2323
get_proper_types, flatten_nested_unions
2424
)
@@ -1025,11 +1025,31 @@ def check_callable_call(self,
10251025
lambda i: self.accept(args[i]))
10261026

10271027
if callee.is_generic():
1028+
need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables)
10281029
callee = freshen_function_type_vars(callee)
10291030
callee = self.infer_function_type_arguments_using_context(
10301031
callee, context)
10311032
callee = self.infer_function_type_arguments(
10321033
callee, args, arg_kinds, formal_to_actual, context)
1034+
if need_refresh:
1035+
# Argument kinds etc. may have changed; recalculate actual-to-formal map
1036+
formal_to_actual = map_actuals_to_formals(
1037+
arg_kinds, arg_names,
1038+
callee.arg_kinds, callee.arg_names,
1039+
lambda i: self.accept(args[i]))
1040+
1041+
param_spec = callee.param_spec()
1042+
if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]:
1043+
arg1 = get_proper_type(self.accept(args[0]))
1044+
arg2 = get_proper_type(self.accept(args[1]))
1045+
if (is_named_instance(arg1, 'builtins.tuple')
1046+
and is_named_instance(arg2, 'builtins.dict')):
1047+
assert isinstance(arg1, Instance)
1048+
assert isinstance(arg2, Instance)
1049+
if (isinstance(arg1.args[0], ParamSpecType)
1050+
and isinstance(arg2.args[1], ParamSpecType)):
1051+
# TODO: Check ParamSpec ids and flavors
1052+
return callee.ret_type, callee
10331053

10341054
arg_types = self.infer_arg_types_in_context(
10351055
callee, args, arg_kinds, formal_to_actual)
@@ -3981,15 +4001,18 @@ def is_valid_var_arg(self, typ: Type) -> bool:
39814001
return (isinstance(typ, TupleType) or
39824002
is_subtype(typ, self.chk.named_generic_type('typing.Iterable',
39834003
[AnyType(TypeOfAny.special_form)])) or
3984-
isinstance(typ, AnyType))
4004+
isinstance(typ, AnyType) or
4005+
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.ARGS))
39854006

39864007
def is_valid_keyword_var_arg(self, typ: Type) -> bool:
39874008
"""Is a type valid as a **kwargs argument?"""
39884009
ret = (
39894010
is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
39904011
[self.named_type('builtins.str'), AnyType(TypeOfAny.special_form)])) or
39914012
is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
3992-
[UninhabitedType(), UninhabitedType()])))
4013+
[UninhabitedType(), UninhabitedType()])) or
4014+
(isinstance(typ, ParamSpecType) and typ.flavor == ParamSpecFlavor.KWARGS)
4015+
)
39934016
if self.chk.options.python_version[0] < 3:
39944017
ret = ret or is_subtype(typ, self.chk.named_generic_type('typing.Mapping',
39954018
[self.named_type('builtins.unicode'), AnyType(TypeOfAny.special_form)]))

mypy/checkmember.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mypy.types import (
77
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike,
88
TypeVarLikeType, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType,
9-
DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType
9+
DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType, ParamSpecType
1010
)
1111
from mypy.nodes import (
1212
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, SymbolTable, Context,
@@ -669,6 +669,9 @@ def f(self: S) -> T: ...
669669
selfarg = item.arg_types[0]
670670
if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))):
671671
new_items.append(item)
672+
elif isinstance(selfarg, ParamSpecType):
673+
# TODO: This is not always right. What's the most reasonable thing to do here?
674+
new_items.append(item)
672675
if not new_items:
673676
# Choose first item for the message (it may be not very helpful for overloads).
674677
msg.incompatible_self_argument(name, dispatched_arg_type, items[0],

mypy/constraints.py

+56-31
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
88
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
99
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
10-
ProperType, get_proper_type, TypeAliasType, is_union_with_any
10+
ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any,
11+
callable_with_ellipsis
1112
)
1213
from mypy.maptype import map_instance_to_supertype
1314
import mypy.subtypes
@@ -398,6 +399,10 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
398399
assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor"
399400
" (should have been handled in infer_constraints)")
400401

402+
def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]:
403+
# Can't infer ParamSpecs from component values (only via Callable[P, T]).
404+
return []
405+
401406
# Non-leaf types
402407

403408
def visit_instance(self, template: Instance) -> List[Constraint]:
@@ -438,14 +443,16 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
438443
# N.B: We use zip instead of indexing because the lengths might have
439444
# mismatches during daemon reprocessing.
440445
for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
441-
# The constraints for generic type parameters depend on variance.
442-
# Include constraints from both directions if invariant.
443-
if tvar.variance != CONTRAVARIANT:
444-
res.extend(infer_constraints(
445-
mapped_arg, instance_arg, self.direction))
446-
if tvar.variance != COVARIANT:
447-
res.extend(infer_constraints(
448-
mapped_arg, instance_arg, neg_op(self.direction)))
446+
# TODO: ParamSpecType
447+
if isinstance(tvar, TypeVarType):
448+
# The constraints for generic type parameters depend on variance.
449+
# Include constraints from both directions if invariant.
450+
if tvar.variance != CONTRAVARIANT:
451+
res.extend(infer_constraints(
452+
mapped_arg, instance_arg, self.direction))
453+
if tvar.variance != COVARIANT:
454+
res.extend(infer_constraints(
455+
mapped_arg, instance_arg, neg_op(self.direction)))
449456
return res
450457
elif (self.direction == SUPERTYPE_OF and
451458
instance.type.has_base(template.type.fullname)):
@@ -454,14 +461,16 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
454461
# N.B: We use zip instead of indexing because the lengths might have
455462
# mismatches during daemon reprocessing.
456463
for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
457-
# The constraints for generic type parameters depend on variance.
458-
# Include constraints from both directions if invariant.
459-
if tvar.variance != CONTRAVARIANT:
460-
res.extend(infer_constraints(
461-
template_arg, mapped_arg, self.direction))
462-
if tvar.variance != COVARIANT:
463-
res.extend(infer_constraints(
464-
template_arg, mapped_arg, neg_op(self.direction)))
464+
# TODO: ParamSpecType
465+
if isinstance(tvar, TypeVarType):
466+
# The constraints for generic type parameters depend on variance.
467+
# Include constraints from both directions if invariant.
468+
if tvar.variance != CONTRAVARIANT:
469+
res.extend(infer_constraints(
470+
template_arg, mapped_arg, self.direction))
471+
if tvar.variance != COVARIANT:
472+
res.extend(infer_constraints(
473+
template_arg, mapped_arg, neg_op(self.direction)))
465474
return res
466475
if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
467476
# We avoid infinite recursion for structural subtypes by checking
@@ -536,32 +545,48 @@ def infer_constraints_from_protocol_members(self,
536545

537546
def visit_callable_type(self, template: CallableType) -> List[Constraint]:
538547
if isinstance(self.actual, CallableType):
539-
cactual = self.actual
540-
# FIX verify argument counts
541-
# FIX what if one of the functions is generic
542548
res: List[Constraint] = []
549+
cactual = self.actual
550+
param_spec = template.param_spec()
551+
if param_spec is None:
552+
# FIX verify argument counts
553+
# FIX what if one of the functions is generic
554+
555+
# We can't infer constraints from arguments if the template is Callable[..., T]
556+
# (with literal '...').
557+
if not template.is_ellipsis_args:
558+
# The lengths should match, but don't crash (it will error elsewhere).
559+
for t, a in zip(template.arg_types, cactual.arg_types):
560+
# Negate direction due to function argument type contravariance.
561+
res.extend(infer_constraints(t, a, neg_op(self.direction)))
562+
else:
563+
# TODO: Direction
564+
# TODO: Deal with arguments that come before param spec ones?
565+
res.append(Constraint(param_spec.id,
566+
SUBTYPE_OF,
567+
cactual.copy_modified(ret_type=NoneType())))
543568

544-
# We can't infer constraints from arguments if the template is Callable[..., T] (with
545-
# literal '...').
546-
if not template.is_ellipsis_args:
547-
# The lengths should match, but don't crash (it will error elsewhere).
548-
for t, a in zip(template.arg_types, cactual.arg_types):
549-
# Negate direction due to function argument type contravariance.
550-
res.extend(infer_constraints(t, a, neg_op(self.direction)))
551569
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
552570
if template.type_guard is not None:
553571
template_ret_type = template.type_guard
554572
if cactual.type_guard is not None:
555573
cactual_ret_type = cactual.type_guard
574+
556575
res.extend(infer_constraints(template_ret_type, cactual_ret_type,
557576
self.direction))
558577
return res
559578
elif isinstance(self.actual, AnyType):
560-
# FIX what if generic
561-
res = self.infer_against_any(template.arg_types, self.actual)
579+
param_spec = template.param_spec()
562580
any_type = AnyType(TypeOfAny.from_another_any, source_any=self.actual)
563-
res.extend(infer_constraints(template.ret_type, any_type, self.direction))
564-
return res
581+
if param_spec is None:
582+
# FIX what if generic
583+
res = self.infer_against_any(template.arg_types, self.actual)
584+
res.extend(infer_constraints(template.ret_type, any_type, self.direction))
585+
return res
586+
else:
587+
return [Constraint(param_spec.id,
588+
SUBTYPE_OF,
589+
callable_with_ellipsis(any_type, any_type, template.fallback))]
565590
elif isinstance(self.actual, Overloaded):
566591
return self.infer_against_overloaded(self.actual, template)
567592
elif isinstance(self.actual, TypeType):

mypy/erasetype.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
55
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
66
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
7-
get_proper_type, TypeAliasType
7+
get_proper_type, TypeAliasType, ParamSpecType
88
)
99
from mypy.nodes import ARG_STAR, ARG_STAR2
1010

@@ -57,6 +57,9 @@ def visit_instance(self, t: Instance) -> ProperType:
5757
def visit_type_var(self, t: TypeVarType) -> ProperType:
5858
return AnyType(TypeOfAny.special_form)
5959

60+
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
61+
return AnyType(TypeOfAny.special_form)
62+
6063
def visit_callable_type(self, t: CallableType) -> ProperType:
6164
# We must preserve the fallback type for overload resolution to work.
6265
any_type = AnyType(TypeOfAny.special_form)
@@ -125,6 +128,11 @@ def visit_type_var(self, t: TypeVarType) -> Type:
125128
return self.replacement
126129
return t
127130

131+
def visit_param_spec(self, t: ParamSpecType) -> Type:
132+
if self.erase_id(t.id):
133+
return self.replacement
134+
return t
135+
128136
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
129137
# Type alias target can't contain bound type variables, so
130138
# it is safe to just erase the arguments.

0 commit comments

Comments
 (0)