Skip to content

Implement basic *args support for variadic generics #13889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Callable, Sequence

import mypy.subtypes
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.expandtype import expand_type, expand_unpack_with_variables
from mypy.nodes import ARG_POS, ARG_STAR, Context
from mypy.types import (
AnyType,
CallableType,
Expand All @@ -16,6 +16,7 @@
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnpackType,
get_proper_type,
)

Expand Down Expand Up @@ -110,7 +111,33 @@ def apply_generic_arguments(
callable = callable.expand_param_spec(nt)

# Apply arguments to argument types.
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
expanded = expand_unpack_with_variables(var_arg.typ, id_to_type)
assert isinstance(expanded, list)
# Handle other cases later.
for t in expanded:
assert not isinstance(t, UnpackType)
star_index = callable.arg_kinds.index(ARG_STAR)
arg_kinds = (
callable.arg_kinds[:star_index]
+ [ARG_POS] * len(expanded)
+ callable.arg_kinds[star_index + 1 :]
)
arg_names = (
callable.arg_names[:star_index]
+ [None] * len(expanded)
+ callable.arg_names[star_index + 1 :]
)
arg_types = (
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
+ expanded
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
)
else:
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
arg_kinds = callable.arg_kinds
arg_names = callable.arg_names

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
Expand All @@ -126,4 +153,6 @@ def apply_generic_arguments(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
arg_kinds=arg_kinds,
arg_names=arg_names,
)
12 changes: 11 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand Down Expand Up @@ -1170,7 +1171,16 @@ def check_func_def(
ctx = typ
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
if typ.arg_kinds[i] == nodes.ARG_STAR:
if not isinstance(arg_type, ParamSpecType):
if isinstance(arg_type, ParamSpecType):
pass
elif isinstance(arg_type, UnpackType):
arg_type = TupleType(
[arg_type],
fallback=self.named_generic_type(
"builtins.tuple", [self.named_type("builtins.object")]
),
)
else:
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
Expand Down
5 changes: 4 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -1397,7 +1398,9 @@ def check_callable_call(
)

if callee.is_generic():
need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables)
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
)
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
callee = self.infer_function_type_arguments(
Expand Down
44 changes: 34 additions & 10 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,41 @@ def infer_constraints_for_callable(
mapper = ArgTypeExpander(context)

for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue
if isinstance(callee.arg_types[i], UnpackType):
unpack_type = callee.arg_types[i]
assert isinstance(unpack_type, UnpackType)

# In this case we are binding all of the actuals to *args
# and we want a constraint that the typevar tuple being unpacked
# is equal to a type list of all the actuals.
actual_types = []
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue

actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
actual_types.append(
mapper.expand_actual_type(
actual_arg_type,
arg_kinds[actual],
callee.arg_names[i],
callee.arg_kinds[i],
)
)

assert isinstance(unpack_type.type, TypeVarTupleType)
constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types)))
else:
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue

actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)

return constraints

Expand Down Expand Up @@ -165,7 +190,6 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons


def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:

orig_template = template
template = get_proper_type(template)
actual = get_proper_type(actual)
Expand Down
74 changes: 48 additions & 26 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_STAR
from mypy.types import (
AnyType,
CallableType,
Expand Down Expand Up @@ -213,31 +214,7 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
assert False, "Mypy bug: unpacking must happen at a higher level"

def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
"""May return either a list of types to unpack to, any, or a single
variable length tuple. The latter may not be valid in all contexts.
"""
if isinstance(t.type, TypeVarTupleType):
repl = get_proper_type(self.variables.get(t.type.id, t))
if isinstance(repl, TupleType):
return repl.items
if isinstance(repl, TypeList):
return repl.items
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
return repl
elif isinstance(repl, AnyType):
# tuple[Any, ...] would be better, but we don't have
# the type info to construct that type here.
return repl
elif isinstance(repl, TypeVarTupleType):
return [UnpackType(typ=repl)]
elif isinstance(repl, UnpackType):
return [repl]
elif isinstance(repl, UninhabitedType):
return None
else:
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
else:
raise NotImplementedError(f"Invalid type to expand: {t.type}")
return expand_unpack_with_variables(t, self.variables)

def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
Expand Down Expand Up @@ -267,8 +244,23 @@ def visit_callable_type(self, t: CallableType) -> Type:
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)

var_arg = t.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
expanded = self.expand_unpack(var_arg.typ)
# Handle other cases later.
assert isinstance(expanded, list)
assert len(expanded) == 1 and isinstance(expanded[0], UnpackType)
star_index = t.arg_kinds.index(ARG_STAR)
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
arg_types = self.expand_types(t.arg_types)

return t.copy_modified(
arg_types=self.expand_types(t.arg_types),
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
Expand Down Expand Up @@ -361,3 +353,33 @@ def expand_types(self, types: Iterable[Type]) -> list[Type]:
for t in types:
a.append(t.accept(self))
return a


def expand_unpack_with_variables(
t: UnpackType, variables: Mapping[TypeVarId, Type]
) -> list[Type] | Instance | AnyType | None:
"""May return either a list of types to unpack to, any, or a single
variable length tuple. The latter may not be valid in all contexts.
"""
if isinstance(t.type, TypeVarTupleType):
repl = get_proper_type(variables.get(t.type.id, t))
if isinstance(repl, TupleType):
return repl.items
if isinstance(repl, TypeList):
return repl.items
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
return repl
elif isinstance(repl, AnyType):
# tuple[Any, ...] would be better, but we don't have
# the type info to construct that type here.
return repl
elif isinstance(repl, TypeVarTupleType):
return [UnpackType(typ=repl)]
elif isinstance(repl, UnpackType):
return [repl]
elif isinstance(repl, UninhabitedType):
return None
else:
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
else:
raise NotImplementedError(f"Invalid type to expand: {t.type}")
4 changes: 4 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
Expand Down Expand Up @@ -2263,6 +2264,9 @@ def format_literal_value(typ: LiteralType) -> str:
elif isinstance(typ, TypeVarType):
# This is similar to non-generic instance types.
return typ.name
elif isinstance(typ, TypeVarTupleType):
# This is similar to non-generic instance types.
return typ.name
elif isinstance(typ, ParamSpecType):
# Concatenate[..., P]
if typ.prefix.arg_types:
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,20 @@ expect_variadic_array(u)
expect_variadic_array_2(u)


[builtins fixtures/tuple.pyi]

[case testPep646TypeVarStarArgs]
from typing import Tuple
from typing_extensions import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")

# TODO: add less trivial tests with prefix/suffix etc.
# TODO: add tests that call with a type var tuple instead of just args.
def args_to_tuple(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this PR work for a non-typevartuple use case? I have seen people use something like this def foo(*args: Unpack[tuple[int, str]]) (but it will be probably more useful with a type alias, to re-use in different functions, similar to how we allow TypedDicts to re-use types for **kwds).

Btw currently this use case may cause a crash, see first example in #13790 (comment)

If this PR helps with this direction, could you please add a test case? Otherwise consider this a sub-feature request.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another common use case with *args which crashes currently.

from typing import Callable
from typing_extensions import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")

def a(var: int) -> None: ...

def func(f: Callable[[Unpack[Ts]], None], *args: Unpack[Ts]) -> None:
    f(*args)

func(a, 42)

reveal_type(args) # N: Revealed type is "Tuple[Unpack[Ts`-1]]"
return args

reveal_type(args_to_tuple(1, 'a')) # N: Revealed type is "Tuple[Literal[1]?, Literal['a']?]"

[builtins fixtures/tuple.pyi]