diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f7b584eadae8..e0e8eca26587 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -162,6 +162,8 @@ def __init__(self, self.is_typeshed_stub = is_typeshed_stub # Names of type aliases encountered while analysing a type will be collected here. self.aliases_used: Set[str] = set() + # Is needed for `TypeGuard` analysis: + self.is_return_type = False def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type: typ = self.visit_unbound_type_nonoptional(t, defining_literal) @@ -562,7 +564,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: variables = t.variables else: variables = self.bind_function_type_variables(t, t) - special = self.anal_type_guard(t.ret_type) + special = self.anal_type_guard(t.ret_type, explicit=True) arg_kinds = t.arg_kinds if len(arg_kinds) >= 2 and arg_kinds[-2] == ARG_STAR and arg_kinds[-1] == ARG_STAR2: arg_types = self.anal_array(t.arg_types[:-2], nested=nested) + [ @@ -571,30 +573,35 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: ] else: arg_types = self.anal_array(t.arg_types, nested=nested) - ret = t.copy_modified(arg_types=arg_types, - ret_type=self.anal_type(t.ret_type, nested=nested), - # If the fallback isn't filled in yet, - # its type will be the falsey FakeInfo - fallback=(t.fallback if t.fallback.type - else self.named_type('builtins.function')), - variables=self.anal_var_defs(variables), - type_guard=special, - ) + ret = t.copy_modified( + arg_types=arg_types, + ret_type=self.anal_type(t.ret_type, nested=nested, is_return_type=True), + # If the fallback isn't filled in yet, + # its type will be the falsey FakeInfo + fallback=(t.fallback if t.fallback.type + else self.named_type('builtins.function')), + variables=self.anal_var_defs(variables), + type_guard=special, + ) return ret - def anal_type_guard(self, t: Type) -> Optional[Type]: + def anal_type_guard(self, t: Type, *, explicit: bool = False) -> Optional[Type]: if isinstance(t, UnboundType): sym = self.lookup_qualified(t.name, t) if sym is not None and sym.node is not None: - return self.anal_type_guard_arg(t, sym.node.fullname) + return self.anal_type_guard_arg(t, sym.node.fullname, explicit=explicit) # TODO: What if it's an Instance? Then use t.type.fullname? return None - def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]: + def anal_type_guard_arg( + self, t: UnboundType, fullname: str, *, explicit: bool = False, + ) -> Optional[Type]: if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'): if len(t.args) != 1: self.fail("TypeGuard must have exactly one type argument", t) return AnyType(TypeOfAny.from_error) + if not (explicit or self.is_return_type): + self.fail('TypeGuard must only be used as a return type', t) return self.anal_type(t.args[0]) return None @@ -1011,17 +1018,26 @@ def anal_array(self, res.append(self.anal_type(t, nested, allow_param_spec=allow_param_spec)) return res - def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = False) -> Type: + def anal_type( + self, + t: Type, + nested: bool = True, + *, + allow_param_spec: bool = False, + is_return_type: bool = False, + ) -> Type: if nested: self.nesting_level += 1 old_allow_required = self.allow_required self.allow_required = False + self.is_return_type = is_return_type try: analyzed = t.accept(self) finally: if nested: self.nesting_level -= 1 self.allow_required = old_allow_required + self.is_return_type = False if (not allow_param_spec and isinstance(analyzed, ParamSpecType) and analyzed.flavor == ParamSpecFlavor.BARE): diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 32fe5e750989..40e7802c26c8 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -55,13 +55,18 @@ def main(a: object, b: object) -> None: [builtins fixtures/tuple.pyi] [case testTypeGuardIsBool] +from typing import List from typing_extensions import TypeGuard -def f(a: TypeGuard[int]) -> None: pass +def f(a: TypeGuard[int]) -> None: # E: TypeGuard must only be used as a return type + pass reveal_type(f) # N: Revealed type is "def (a: builtins.bool)" -a: TypeGuard[int] +def f1() -> List[TypeGuard[int]]: # E: TypeGuard must only be used as a return type + pass +reveal_type(f1()) # N: Revealed type is "builtins.list[builtins.bool]" +a: TypeGuard[int] # E: TypeGuard must only be used as a return type reveal_type(a) # N: Revealed type is "builtins.bool" class C: - a: TypeGuard[int] + a: TypeGuard[int] # E: TypeGuard must only be used as a return type reveal_type(C().a) # N: Revealed type is "builtins.bool" [builtins fixtures/tuple.pyi]