diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 42e47ed9aa0f..c25668ba5ae4 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -136,6 +136,7 @@ LiteralValue, NoneType, Overloaded, + Parameters, ParamSpecFlavor, ParamSpecType, PartialType, @@ -1429,6 +1430,7 @@ def check_callable_call( need_refresh = any( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) + old_callee = callee callee = freshen_function_type_vars(callee) callee = self.infer_function_type_arguments_using_context(callee, context) if need_refresh: @@ -1443,7 +1445,7 @@ def check_callable_call( lambda i: self.accept(args[i]), ) callee = self.infer_function_type_arguments( - callee, args, arg_kinds, formal_to_actual, context + callee, args, arg_kinds, formal_to_actual, context, old_callee ) if need_refresh: formal_to_actual = map_actuals_to_formals( @@ -1733,6 +1735,7 @@ def infer_function_type_arguments( arg_kinds: list[ArgKind], formal_to_actual: list[list[int]], context: Context, + unfreshened_callee_type: CallableType, ) -> CallableType: """Infer the type arguments for a generic callee type. @@ -1776,6 +1779,28 @@ def infer_function_type_arguments( callee_type, args, arg_kinds, formal_to_actual, inferred_args, context ) + return_type = get_proper_type(callee_type.ret_type) + if isinstance(return_type, CallableType): + # fixup: + # def [T] () -> def (T) -> T + # into + # def () -> def [T] (T) -> T + for i, argument in enumerate(inferred_args): + if isinstance(get_proper_type(argument), UninhabitedType): + # un-"freshen" the type variable :^) + variable = unfreshened_callee_type.variables[i] + inferred_args[i] = variable + + # handle multiple type variables + return_type = return_type.copy_modified( + variables=[*return_type.variables, variable] + ) + + callee_type = callee_type.copy_modified( + # am I allowed to assign the get_proper_type'd thing? + ret_type=return_type + ) + if ( callee_type.special_sig == "dict" and len(inferred_args) == 2 @@ -4070,6 +4095,20 @@ def apply_type_arguments_to_callable( tp = get_proper_type(tp) if isinstance(tp, CallableType): + if ( + len(tp.variables) == 1 + and isinstance(tp.variables[0], ParamSpecType) + and ( + len(args) != 1 + or not isinstance( + get_proper_type(args[0]), (Parameters, ParamSpecType, AnyType) + ) + ) + ): + # TODO: I don't think AnyType here is valid in the general case, there's 2 cases: + # 1. invalid paramspec expression (in which case we should transform it into an ellipsis) + # 2. user passed it (in which case we should pass it into Parameters(...)) + args = [Parameters(args, [nodes.ARG_POS for _ in args], [None for _ in args])] if len(tp.variables) != len(args): if tp.is_type_obj() and tp.type_object().fullname == "builtins.tuple": # TODO: Specialize the callable for the type arguments diff --git a/mypy/constraints.py b/mypy/constraints.py index c8c3c7933b6e..689d294b9ef6 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, List, Sequence +from typing import TYPE_CHECKING, Iterable, List, Sequence, Union from typing_extensions import Final import mypy.subtypes @@ -176,12 +176,13 @@ def infer_constraints_for_callable( def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: """Infer type constraints. - Match a template type, which may contain type variable references, - recursively against a type which does not contain (the same) type - variable references. The result is a list of type constrains of - form 'T is a supertype/subtype of x', where T is a type variable - present in the template and x is a type without reference to type - variables present in the template. + Match a template type, which may contain type variable and parameter + specification references, recursively against a type which does not + contain (the same) type variable and parameter specification references. + The result is a list of type constraints of form 'T is a supertype/subtype + of x', where T is a type variable present in the template or a parameter + specification without its prefix and x is a type without reference to type + variables nor parameters present in the template. Assume T and S are type variables. Now the following results can be calculated (read as '(template, actual) --> result'): @@ -192,6 +193,23 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons ((T, S), (X, Y)) --> T :> X and S :> Y (X[T], Any) --> T <: Any and T :> Any + Assume P and Q are prefix-less parameter specifications. The following + results can be calculated in a similar format: + + (P, [...W]) --> P :> [...W] + (X[P], X[[...W]]) --> P :> [...W] + // note that parameter specifications are *always* contravariant as + // they echo Callable arguments. + ((P, P), ([...W], [...U])) --> P :> [...W] and P :> [...U] + ((P, Q), ([...W], [...U])) --> P :> [...W] and Q :> [...U] + (P, ...) --> P :> ... + + With prefixes (note that I am not sure these cases are implemented): + + ([...Z, P], [...Z, ...W]) --> P :> [...W] + ([...Z, P], Q) --> [...Z, P] :> Q + (P, [...Z, Q]) --> P :> [...Z, Q] + The constraints are represented as Constraint objects. """ if any( @@ -695,19 +713,37 @@ def visit_instance(self, template: Instance) -> list[Constraint]: from_concat = bool(prefix.arg_types) or suffix.from_concatenate suffix = suffix.copy_modified(from_concatenate=from_concat) + prefix = mapped_arg.prefix + length = len(prefix.arg_types) if isinstance(suffix, Parameters) or isinstance(suffix, CallableType): # no such thing as variance for ParamSpecs # TODO: is there a case I am missing? - # TODO: constraints between prefixes - prefix = mapped_arg.prefix - suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types) :], - suffix.arg_kinds[len(prefix.arg_kinds) :], - suffix.arg_names[len(prefix.arg_names) :], + res.append( + Constraint( + mapped_arg, + SUPERTYPE_OF, + suffix.copy_modified( + arg_types=suffix.arg_types[length:], + arg_kinds=suffix.arg_kinds[length:], + arg_names=suffix.arg_names[length:], + ), + ) ) - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) elif isinstance(suffix, ParamSpecType): - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + suffix_prefix = suffix.prefix + res.append( + Constraint( + mapped_arg, + SUPERTYPE_OF, + suffix.copy_modified( + prefix=suffix_prefix.copy_modified( + arg_types=suffix_prefix.arg_types[length:], + arg_kinds=suffix_prefix.arg_kinds[length:], + arg_names=suffix_prefix.arg_names[length:], + ) + ), + ) + ) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -918,14 +954,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # sometimes, it appears we try to get constraints between two paramspec callables? # TODO: Direction - # TODO: check the prefixes match prefix = param_spec.prefix prefix_len = len(prefix.arg_types) cactual_ps = cactual.param_spec() + cactual_prefix: Union[Parameters, CallableType] + if cactual_ps: + cactual_prefix = cactual_ps.prefix + else: + cactual_prefix = cactual + + max_prefix_len = len( + [k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)] + ) + prefix_len = min(prefix_len, max_prefix_len) + + # we could check the prefixes match here, but that should be caught elsewhere. if not cactual_ps: - max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) - prefix_len = min(prefix_len, max_prefix_len) res.append( Constraint( param_spec, @@ -939,7 +984,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) ) else: - res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps)) + # guaranteed due to if conditions + assert isinstance(cactual_prefix, Parameters) + + res.append( + Constraint( + param_spec, + SUBTYPE_OF, + cactual_ps.copy_modified( + prefix=cactual_prefix.copy_modified( + arg_types=cactual_prefix.arg_types[prefix_len:], + arg_kinds=cactual_prefix.arg_kinds[prefix_len:], + arg_names=cactual_prefix.arg_names[prefix_len:], + ) + ), + ) + ) # compare prefixes cactual_prefix = cactual.copy_modified( diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 6533d0c4e0f9..43dbe3a0d8e4 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -172,7 +172,11 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: def visit_param_spec(self, t: ParamSpecType) -> Type: if self.erase_id(t.id): - return self.replacement + return t.prefix.copy_modified( + arg_types=t.prefix.arg_types + [self.replacement, self.replacement], + arg_kinds=t.prefix.arg_kinds + [ARG_STAR, ARG_STAR2], + arg_names=t.prefix.arg_names + [None, None], + ) return t def visit_type_alias_type(self, t: TypeAliasType) -> Type: diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 21c3a592669e..5e12bf078d54 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -138,7 +138,6 @@ def freshen_function_type_vars(callee: F) -> F: if isinstance(v, TypeVarType): tv: TypeVarLikeType = TypeVarType.new_unification_variable(v) elif isinstance(v, TypeVarTupleType): - assert isinstance(v, TypeVarTupleType) tv = TypeVarTupleType.new_unification_variable(v) else: assert isinstance(v, ParamSpecType) diff --git a/mypy/join.py b/mypy/join.py index 62d256f4440f..ec0ba865367c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -97,7 +97,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: else: # ParamSpec type variables behave the same, independent of variance if not is_equivalent(ta, sa): - return get_proper_type(type_var.upper_bound) + return object_from_instance(t) new_type = join_types(ta, sa, self) assert new_type is not None args.append(new_type) @@ -311,9 +311,11 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: return self.default(self.s) def visit_param_spec(self, t: ParamSpecType) -> ProperType: + # TODO: should this mirror the `isinstance(...) ...` above? if self.s == t: return t - return self.default(self.s) + else: + return self.default(self.s) def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: if self.s == t: diff --git a/mypy/nodes.py b/mypy/nodes.py index 8365e508f8e0..25ed0ea6ac36 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2519,6 +2519,12 @@ class ParamSpecExpr(TypeVarLikeExpr): __match_args__ = ("name", "upper_bound") + # TODO: Technically the variance cannot be customized. Nor can the upper bound. + def __init__( + self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT + ) -> None: + super().__init__(name, fullname, upper_bound, variance) + def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_paramspec_expr(self) diff --git a/mypy/semanal.py b/mypy/semanal.py index fee7d9d9a520..e530870eb456 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -4168,7 +4168,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: if not call.analyzed: paramspec_var = ParamSpecExpr( - name, self.qualified_name(name), self.object_type(), INVARIANT + name, self.qualified_name(name), self.top_caller(), INVARIANT ) paramspec_var.line = call.line call.analyzed = paramspec_var @@ -5612,6 +5612,14 @@ def lookup_fully_qualified_or_none(self, fullname: str) -> SymbolTableNode | Non def object_type(self) -> Instance: return self.named_type("builtins.object") + def top_caller(self) -> Parameters: + return Parameters( + arg_types=[self.object_type(), self.object_type()], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=[None, None], + is_ellipsis_args=True, + ) + def str_type(self) -> Instance: return self.named_type("builtins.str") diff --git a/mypy/strconv.py b/mypy/strconv.py index b2e9da5dbf6a..2fad788cb9f3 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -484,8 +484,9 @@ def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> str: a += ["Variance(COVARIANT)"] if o.variance == mypy.nodes.CONTRAVARIANT: a += ["Variance(CONTRAVARIANT)"] - if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"): - a += [f"UpperBound({o.upper_bound})"] + # ParamSpecs do not have upper bounds!!! (should this be left for future proofing?) + # if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"): + # a += [f"UpperBound({o.upper_bound})"] return self.dump(a, o) def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> str: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9c6518b9e487..0da2cc192806 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -592,6 +592,8 @@ def check_mixed( ): nominal = False else: + # TODO: I'm *pretty* sure `CONTRAVARIANT` should be here... + # But it's erroring! if not check_type_parameter( lefta, righta, COVARIANT, self.proper_subtype, self.subtype_context ): diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f3329af6207a..63773da656fa 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -447,6 +447,9 @@ def pack_paramspec_args(self, an_args: Sequence[Type]) -> list[Type]: if count > 0: first_arg = get_proper_type(an_args[0]) if not (count == 1 and isinstance(first_arg, (Parameters, ParamSpecType, AnyType))): + # TODO: I don't think AnyType here is valid in the general case, there's 2 cases: + # 1. invalid paramspec expression (in which case we should transform it into an ellipsis) + # 2. user passed it (in which case we should pass it into Parameters(...)) return [Parameters(an_args, [ARG_POS] * count, [None] * count)] return list(an_args) diff --git a/mypy/types.py b/mypy/types.py index 994eb290fff3..7e18e54e2b52 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -651,7 +651,8 @@ class ParamSpecType(TypeVarLikeType): The upper_bound is really used as a fallback type -- it's shared with TypeVarType for simplicity. It can't be specified by the user and the value is directly derived from the flavor (currently - always just 'object'). + always just '(*Any, **Any)' or '(*object, **object)' depending on + context). """ __slots__ = ("flavor", "prefix") @@ -696,13 +697,14 @@ def copy_modified( id: Bogus[TypeVarId | int] = _dummy, flavor: int = _dummy_int, prefix: Bogus[Parameters] = _dummy, + upper_bound: Bogus[Type] = _dummy, ) -> ParamSpecType: return ParamSpecType( self.name, self.fullname, id if id is not _dummy else self.id, flavor if flavor != _dummy_int else self.flavor, - self.upper_bound, + upper_bound if upper_bound is not _dummy else self.upper_bound, line=self.line, column=self.column, prefix=prefix if prefix is not _dummy else self.prefix, @@ -1986,7 +1988,18 @@ def param_spec(self) -> ParamSpecType | None: # TODO: confirm that all arg kinds are positional prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) - return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix) + # TODO: should this take in `object`s? + any_type = AnyType(TypeOfAny.special_form) + return arg_type.copy_modified( + flavor=ParamSpecFlavor.BARE, + prefix=prefix, + upper_bound=Parameters( + arg_types=[any_type, any_type], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=[None, None], + is_ellipsis_args=True, + ), + ) def expand_param_spec( self, c: CallableType | Parameters, no_prefix: bool = False diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 2dc19d319a0d..757e5205762e 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3103,8 +3103,11 @@ T = TypeVar('T') def f(x: Optional[T] = None) -> Callable[..., T]: ... -x = f() # E: Need type annotation for "x" +# TODO: should this warn about needed an annotation? This behavior still _works_... +x = f() +reveal_type(x) # N: Revealed type is "def [T] (*Any, **Any) -> T`-1" y = x +reveal_type(y) # N: Revealed type is "def [T] (*Any, **Any) -> T`-1" [case testDontNeedAnnotationForCallable] from typing import TypeVar, Optional, Callable, NoReturn diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 7ef0485f7841..bcc50afc27e5 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1521,3 +1521,51 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ... @identity def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testRemoveSharedPrefixForConstraining] +# copied essentially verbatim from testing usages of ParamSpec +from typing import TypeVar, Callable, Generic +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +T = TypeVar("T") +V = TypeVar("V") +R = TypeVar("R") + +def command_builder() -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, Callable[[T], R]]]: + def transformer(f: Callable[Concatenate[T, P], R]) -> Callable[P, Callable[[T], R]]: + def returned(*args: P.args, **kwargs: P.kwargs) -> Callable[[T], R]: + def returned_transformer(z: T) -> R: + return f(z, *args, **kwargs) + + return returned_transformer + + return returned + + reveal_type(transformer) # N: Revealed type is "def [T, P, R] (f: def (T`-1, *P.args, **P.kwargs) -> R`-3) -> def (*P.args, **P.kwargs) -> def (T`-1) -> R`-3" + return transformer + +class Example(Generic[P]): + pass + +@command_builder() +def test(ex: Example[P]) -> Example[Concatenate[int, P]]: + ... + +ex: Example[int] = test()(reveal_type(Example())) # N: Revealed type is "__main__.Example[[]]" +reveal_type(test()(Example[int]())) # N: Revealed type is "__main__.Example[[builtins.int, builtins.int]]" +ex = test()(Example[int]()) # E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[[]]" +[builtins fixtures/paramspec.pyi] + +[case testRuntimeSpecialParamspecLiteralSyntax] +import sub + +reveal_type(sub.Ex[None]()) # N: Revealed type is "sub.Ex[[None]]" +[file sub/__init__.py] +from typing_extensions import ParamSpec +from typing import Generic + +P = ParamSpec("P") + +class Ex(Generic[P]): ... +[builtins fixtures/paramspec.pyi]