Skip to content

Commit c810a9c

Browse files
authored
Implement basic *args support for variadic generics (#13889)
This implements the most basic support for the *args feature but various edge cases are not handled in this PR because of the large volume of places that needed to be modified to support this. In particular, we need to special handle the ARG_STAR argument in several places for the case where the type is a UnpackType. Finally when we actually check a function we need to construct a TupleType instead of a builtins.tuple.
1 parent f12faae commit c810a9c

File tree

7 files changed

+149
-41
lines changed

7 files changed

+149
-41
lines changed

mypy/applytype.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Callable, Sequence
44

55
import mypy.subtypes
6-
from mypy.expandtype import expand_type
7-
from mypy.nodes import Context
6+
from mypy.expandtype import expand_type, expand_unpack_with_variables
7+
from mypy.nodes import ARG_POS, ARG_STAR, Context
88
from mypy.types import (
99
AnyType,
1010
CallableType,
@@ -16,6 +16,7 @@
1616
TypeVarLikeType,
1717
TypeVarTupleType,
1818
TypeVarType,
19+
UnpackType,
1920
get_proper_type,
2021
)
2122

@@ -110,7 +111,33 @@ def apply_generic_arguments(
110111
callable = callable.expand_param_spec(nt)
111112

112113
# Apply arguments to argument types.
113-
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
114+
var_arg = callable.var_arg()
115+
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
116+
expanded = expand_unpack_with_variables(var_arg.typ, id_to_type)
117+
assert isinstance(expanded, list)
118+
# Handle other cases later.
119+
for t in expanded:
120+
assert not isinstance(t, UnpackType)
121+
star_index = callable.arg_kinds.index(ARG_STAR)
122+
arg_kinds = (
123+
callable.arg_kinds[:star_index]
124+
+ [ARG_POS] * len(expanded)
125+
+ callable.arg_kinds[star_index + 1 :]
126+
)
127+
arg_names = (
128+
callable.arg_names[:star_index]
129+
+ [None] * len(expanded)
130+
+ callable.arg_names[star_index + 1 :]
131+
)
132+
arg_types = (
133+
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
134+
+ expanded
135+
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
136+
)
137+
else:
138+
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
139+
arg_kinds = callable.arg_kinds
140+
arg_names = callable.arg_names
114141

115142
# Apply arguments to TypeGuard if any.
116143
if callable.type_guard is not None:
@@ -126,4 +153,6 @@ def apply_generic_arguments(
126153
ret_type=expand_type(callable.ret_type, id_to_type),
127154
variables=remaining_tvars,
128155
type_guard=type_guard,
156+
arg_kinds=arg_kinds,
157+
arg_names=arg_names,
129158
)

mypy/checker.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@
202202
UnboundType,
203203
UninhabitedType,
204204
UnionType,
205+
UnpackType,
205206
flatten_nested_unions,
206207
get_proper_type,
207208
get_proper_types,
@@ -1170,7 +1171,16 @@ def check_func_def(
11701171
ctx = typ
11711172
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
11721173
if typ.arg_kinds[i] == nodes.ARG_STAR:
1173-
if not isinstance(arg_type, ParamSpecType):
1174+
if isinstance(arg_type, ParamSpecType):
1175+
pass
1176+
elif isinstance(arg_type, UnpackType):
1177+
arg_type = TupleType(
1178+
[arg_type],
1179+
fallback=self.named_generic_type(
1180+
"builtins.tuple", [self.named_type("builtins.object")]
1181+
),
1182+
)
1183+
else:
11741184
# builtins.tuple[T] is typing.Tuple[T, ...]
11751185
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
11761186
elif typ.arg_kinds[i] == nodes.ARG_STAR2:

mypy/checkexpr.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
TypedDictType,
146146
TypeOfAny,
147147
TypeType,
148+
TypeVarTupleType,
148149
TypeVarType,
149150
UninhabitedType,
150151
UnionType,
@@ -1397,7 +1398,9 @@ def check_callable_call(
13971398
)
13981399

13991400
if callee.is_generic():
1400-
need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables)
1401+
need_refresh = any(
1402+
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
1403+
)
14011404
callee = freshen_function_type_vars(callee)
14021405
callee = self.infer_function_type_arguments_using_context(callee, context)
14031406
callee = self.infer_function_type_arguments(

mypy/constraints.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,41 @@ def infer_constraints_for_callable(
111111
mapper = ArgTypeExpander(context)
112112

113113
for i, actuals in enumerate(formal_to_actual):
114-
for actual in actuals:
115-
actual_arg_type = arg_types[actual]
116-
if actual_arg_type is None:
117-
continue
114+
if isinstance(callee.arg_types[i], UnpackType):
115+
unpack_type = callee.arg_types[i]
116+
assert isinstance(unpack_type, UnpackType)
117+
118+
# In this case we are binding all of the actuals to *args
119+
# and we want a constraint that the typevar tuple being unpacked
120+
# is equal to a type list of all the actuals.
121+
actual_types = []
122+
for actual in actuals:
123+
actual_arg_type = arg_types[actual]
124+
if actual_arg_type is None:
125+
continue
118126

119-
actual_type = mapper.expand_actual_type(
120-
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
121-
)
122-
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
123-
constraints.extend(c)
127+
actual_types.append(
128+
mapper.expand_actual_type(
129+
actual_arg_type,
130+
arg_kinds[actual],
131+
callee.arg_names[i],
132+
callee.arg_kinds[i],
133+
)
134+
)
135+
136+
assert isinstance(unpack_type.type, TypeVarTupleType)
137+
constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types)))
138+
else:
139+
for actual in actuals:
140+
actual_arg_type = arg_types[actual]
141+
if actual_arg_type is None:
142+
continue
143+
144+
actual_type = mapper.expand_actual_type(
145+
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
146+
)
147+
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
148+
constraints.extend(c)
124149

125150
return constraints
126151

@@ -165,7 +190,6 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
165190

166191

167192
def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
168-
169193
orig_template = template
170194
template = get_proper_type(template)
171195
actual = get_proper_type(actual)

mypy/expandtype.py

+48-26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload
44

5+
from mypy.nodes import ARG_STAR
56
from mypy.types import (
67
AnyType,
78
CallableType,
@@ -213,31 +214,7 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
213214
assert False, "Mypy bug: unpacking must happen at a higher level"
214215

215216
def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
216-
"""May return either a list of types to unpack to, any, or a single
217-
variable length tuple. The latter may not be valid in all contexts.
218-
"""
219-
if isinstance(t.type, TypeVarTupleType):
220-
repl = get_proper_type(self.variables.get(t.type.id, t))
221-
if isinstance(repl, TupleType):
222-
return repl.items
223-
if isinstance(repl, TypeList):
224-
return repl.items
225-
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
226-
return repl
227-
elif isinstance(repl, AnyType):
228-
# tuple[Any, ...] would be better, but we don't have
229-
# the type info to construct that type here.
230-
return repl
231-
elif isinstance(repl, TypeVarTupleType):
232-
return [UnpackType(typ=repl)]
233-
elif isinstance(repl, UnpackType):
234-
return [repl]
235-
elif isinstance(repl, UninhabitedType):
236-
return None
237-
else:
238-
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
239-
else:
240-
raise NotImplementedError(f"Invalid type to expand: {t.type}")
217+
return expand_unpack_with_variables(t, self.variables)
241218

242219
def visit_parameters(self, t: Parameters) -> Type:
243220
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
@@ -267,8 +244,23 @@ def visit_callable_type(self, t: CallableType) -> Type:
267244
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
268245
)
269246

247+
var_arg = t.var_arg()
248+
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
249+
expanded = self.expand_unpack(var_arg.typ)
250+
# Handle other cases later.
251+
assert isinstance(expanded, list)
252+
assert len(expanded) == 1 and isinstance(expanded[0], UnpackType)
253+
star_index = t.arg_kinds.index(ARG_STAR)
254+
arg_types = (
255+
self.expand_types(t.arg_types[:star_index])
256+
+ expanded
257+
+ self.expand_types(t.arg_types[star_index + 1 :])
258+
)
259+
else:
260+
arg_types = self.expand_types(t.arg_types)
261+
270262
return t.copy_modified(
271-
arg_types=self.expand_types(t.arg_types),
263+
arg_types=arg_types,
272264
ret_type=t.ret_type.accept(self),
273265
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
274266
)
@@ -361,3 +353,33 @@ def expand_types(self, types: Iterable[Type]) -> list[Type]:
361353
for t in types:
362354
a.append(t.accept(self))
363355
return a
356+
357+
358+
def expand_unpack_with_variables(
359+
t: UnpackType, variables: Mapping[TypeVarId, Type]
360+
) -> list[Type] | Instance | AnyType | None:
361+
"""May return either a list of types to unpack to, any, or a single
362+
variable length tuple. The latter may not be valid in all contexts.
363+
"""
364+
if isinstance(t.type, TypeVarTupleType):
365+
repl = get_proper_type(variables.get(t.type.id, t))
366+
if isinstance(repl, TupleType):
367+
return repl.items
368+
if isinstance(repl, TypeList):
369+
return repl.items
370+
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
371+
return repl
372+
elif isinstance(repl, AnyType):
373+
# tuple[Any, ...] would be better, but we don't have
374+
# the type info to construct that type here.
375+
return repl
376+
elif isinstance(repl, TypeVarTupleType):
377+
return [UnpackType(typ=repl)]
378+
elif isinstance(repl, UnpackType):
379+
return [repl]
380+
elif isinstance(repl, UninhabitedType):
381+
return None
382+
else:
383+
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
384+
else:
385+
raise NotImplementedError(f"Invalid type to expand: {t.type}")

mypy/messages.py

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
TypedDictType,
8181
TypeOfAny,
8282
TypeType,
83+
TypeVarTupleType,
8384
TypeVarType,
8485
UnboundType,
8586
UninhabitedType,
@@ -2263,6 +2264,9 @@ def format_literal_value(typ: LiteralType) -> str:
22632264
elif isinstance(typ, TypeVarType):
22642265
# This is similar to non-generic instance types.
22652266
return typ.name
2267+
elif isinstance(typ, TypeVarTupleType):
2268+
# This is similar to non-generic instance types.
2269+
return typ.name
22662270
elif isinstance(typ, ParamSpecType):
22672271
# Concatenate[..., P]
22682272
if typ.prefix.arg_types:

test-data/unit/check-typevar-tuple.test

+16
Original file line numberDiff line numberDiff line change
@@ -346,4 +346,20 @@ expect_variadic_array(u)
346346
expect_variadic_array_2(u)
347347

348348

349+
[builtins fixtures/tuple.pyi]
350+
351+
[case testPep646TypeVarStarArgs]
352+
from typing import Tuple
353+
from typing_extensions import TypeVarTuple, Unpack
354+
355+
Ts = TypeVarTuple("Ts")
356+
357+
# TODO: add less trivial tests with prefix/suffix etc.
358+
# TODO: add tests that call with a type var tuple instead of just args.
359+
def args_to_tuple(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]:
360+
reveal_type(args) # N: Revealed type is "Tuple[Unpack[Ts`-1]]"
361+
return args
362+
363+
reveal_type(args_to_tuple(1, 'a')) # N: Revealed type is "Tuple[Literal[1]?, Literal['a']?]"
364+
349365
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)