diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bb92ba65b81e..b3851acf4747 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3108,112 +3108,141 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla def visit_super_expr(self, e: SuperExpr) -> Type: """Type check a super expression (non-lvalue).""" - self.check_super_arguments(e) - t = self.analyze_super(e, False) - return t - def check_super_arguments(self, e: SuperExpr) -> None: - """Check arguments in a super(...) call.""" - if ARG_STAR in e.call.arg_kinds: + # We have an expression like super(T, var).member + + # First compute the types of T and var + types = self._super_arg_types(e) + if isinstance(types, tuple): + type_type, instance_type = types + else: + return types + + # Now get the MRO + type_info = type_info_from_type(type_type) + if type_info is None: + self.chk.fail(message_registry.UNSUPPORTED_ARG_1_FOR_SUPER, e) + return AnyType(TypeOfAny.from_error) + + instance_info = type_info_from_type(instance_type) + if instance_info is None: + self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e) + return AnyType(TypeOfAny.from_error) + + mro = instance_info.mro + + # The base is the first MRO entry *after* type_info that has a member + # with the right name + try: + index = mro.index(type_info) + except ValueError: + self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e) + return AnyType(TypeOfAny.from_error) + + for base in mro[index+1:]: + if e.name in base.names or base == mro[-1]: + if e.info and e.info.fallback_to_any and base == mro[-1]: + # There's an undefined base class, and we're at the end of the + # chain. That's not an error. + return AnyType(TypeOfAny.special_form) + + return analyze_member_access(name=e.name, + typ=instance_type, + is_lvalue=False, + is_super=True, + is_operator=False, + original_type=instance_type, + override_info=base, + context=e, + msg=self.msg, + chk=self.chk, + in_literal_context=self.is_literal_context()) + + assert False, 'unreachable' + + def _super_arg_types(self, e: SuperExpr) -> Union[Type, Tuple[Type, Type]]: + """ + Computes the types of the type and instance expressions in super(T, instance), or the + implicit ones for zero-argument super() expressions. Returns a single type for the whole + super expression when possible (for errors, anys), otherwise the pair of computed types. + """ + + if not self.chk.in_checked_function(): + return AnyType(TypeOfAny.unannotated) + elif len(e.call.args) == 0: + if self.chk.options.python_version[0] == 2: + self.chk.fail(message_registry.TOO_FEW_ARGS_FOR_SUPER, e) + return AnyType(TypeOfAny.from_error) + elif not e.info: + # This has already been reported by the semantic analyzer. + return AnyType(TypeOfAny.from_error) + elif self.chk.scope.active_class(): + self.chk.fail(message_registry.SUPER_OUTSIDE_OF_METHOD_NOT_SUPPORTED, e) + return AnyType(TypeOfAny.from_error) + + # Zero-argument super() is like super(, ) + current_type = fill_typevars(e.info) + type_type = TypeType(current_type) # type: Type + + # Use the type of the self argument, in case it was annotated + method = self.chk.scope.top_function() + assert method is not None + if method.arguments: + instance_type = method.arguments[0].variable.type or current_type # type: Type + else: + self.chk.fail(message_registry.SUPER_ENCLOSING_POSITIONAL_ARGS_REQUIRED, e) + return AnyType(TypeOfAny.from_error) + elif ARG_STAR in e.call.arg_kinds: self.chk.fail(message_registry.SUPER_VARARGS_NOT_SUPPORTED, e) - elif e.call.args and set(e.call.arg_kinds) != {ARG_POS}: + return AnyType(TypeOfAny.from_error) + elif set(e.call.arg_kinds) != {ARG_POS}: self.chk.fail(message_registry.SUPER_POSITIONAL_ARGS_REQUIRED, e) + return AnyType(TypeOfAny.from_error) elif len(e.call.args) == 1: self.chk.fail(message_registry.SUPER_WITH_SINGLE_ARG_NOT_SUPPORTED, e) - elif len(e.call.args) > 2: - self.chk.fail(message_registry.TOO_MANY_ARGS_FOR_SUPER, e) - elif self.chk.options.python_version[0] == 2 and len(e.call.args) == 0: - self.chk.fail(message_registry.TOO_FEW_ARGS_FOR_SUPER, e) + return AnyType(TypeOfAny.from_error) elif len(e.call.args) == 2: - type_obj_type = self.accept(e.call.args[0]) + type_type = self.accept(e.call.args[0]) instance_type = self.accept(e.call.args[1]) - if isinstance(type_obj_type, FunctionLike) and type_obj_type.is_type_obj(): - type_info = type_obj_type.type_object() - elif isinstance(type_obj_type, TypeType): - item = type_obj_type.item - if isinstance(item, AnyType): - # Could be anything. - return - if isinstance(item, TupleType): - # Handle named tuples and other Tuple[...] subclasses. - item = tuple_fallback(item) - if not isinstance(item, Instance): - # A complicated type object type. Too tricky, give up. - # TODO: Do something more clever here. - self.chk.fail(message_registry.UNSUPPORTED_ARG_1_FOR_SUPER, e) - return - type_info = item.type - elif isinstance(type_obj_type, AnyType): - return + else: + self.chk.fail(message_registry.TOO_MANY_ARGS_FOR_SUPER, e) + return AnyType(TypeOfAny.from_error) + + # Imprecisely assume that the type is the current class + if isinstance(type_type, AnyType): + if e.info: + type_type = TypeType(fill_typevars(e.info)) else: - self.msg.first_argument_for_super_must_be_type(type_obj_type, e) - return + return AnyType(TypeOfAny.from_another_any, source_any=type_type) + elif isinstance(type_type, TypeType): + type_item = type_type.item + if isinstance(type_item, AnyType): + if e.info: + type_type = TypeType(fill_typevars(e.info)) + else: + return AnyType(TypeOfAny.from_another_any, source_any=type_item) - if isinstance(instance_type, (Instance, TupleType, TypeVarType)): - if isinstance(instance_type, TypeVarType): - # Needed for generic self. - instance_type = instance_type.upper_bound - if not isinstance(instance_type, (Instance, TupleType)): - # Too tricky, give up. - # TODO: Do something more clever here. - self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e) - return - if isinstance(instance_type, TupleType): - # Needed for named tuples and other Tuple[...] subclasses. - instance_type = tuple_fallback(instance_type) - if type_info not in instance_type.type.mro: - self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e) - elif isinstance(instance_type, TypeType) or (isinstance(instance_type, FunctionLike) - and instance_type.is_type_obj()): - # TODO: Check whether this is a valid type object here. - pass - elif not isinstance(instance_type, AnyType): - self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e) - - def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type: - """Type check a super expression.""" - if e.info and e.info.bases: - # TODO fix multiple inheritance etc - if len(e.info.mro) < 2: - self.chk.fail('Internal error: unexpected mro for {}: {}'.format( - e.info.name(), e.info.mro), e) - return AnyType(TypeOfAny.from_error) - for base in e.info.mro[1:]: - if e.name in base.names or base == e.info.mro[-1]: - if e.info.fallback_to_any and base == e.info.mro[-1]: - # There's an undefined base class, and we're - # at the end of the chain. That's not an error. - return AnyType(TypeOfAny.special_form) - if not self.chk.in_checked_function(): - return AnyType(TypeOfAny.unannotated) - if self.chk.scope.active_class() is not None: - self.chk.fail(message_registry.SUPER_OUTSIDE_OF_METHOD_NOT_SUPPORTED, e) - return AnyType(TypeOfAny.from_error) - method = self.chk.scope.top_function() - assert method is not None - args = method.arguments - # super() in a function with empty args is an error; we - # need something in declared_self. - if not args: - self.chk.fail(message_registry.SUPER_ENCLOSING_POSITIONAL_ARGS_REQUIRED, e) - return AnyType(TypeOfAny.from_error) - declared_self = args[0].variable.type or fill_typevars(e.info) - return analyze_member_access(name=e.name, - typ=fill_typevars(e.info), - is_lvalue=False, - is_super=True, - is_operator=False, - original_type=declared_self, - override_info=base, - context=e, - msg=self.msg, - chk=self.chk, - in_literal_context=self.is_literal_context()) - assert False, 'unreachable' - else: - # Invalid super. This has been reported by the semantic analyzer. + if (not isinstance(type_type, TypeType) + and not (isinstance(type_type, FunctionLike) and type_type.is_type_obj())): + self.msg.first_argument_for_super_must_be_type(type_type, e) return AnyType(TypeOfAny.from_error) + # Imprecisely assume that the instance is of the current class + if isinstance(instance_type, AnyType): + if e.info: + instance_type = fill_typevars(e.info) + else: + return AnyType(TypeOfAny.from_another_any, source_any=instance_type) + elif isinstance(instance_type, TypeType): + instance_item = instance_type.item + if isinstance(instance_item, AnyType): + if e.info: + instance_type = TypeType(fill_typevars(e.info)) + else: + return AnyType(TypeOfAny.from_another_any, source_any=instance_item) + + return type_type, instance_type + def visit_slice_expr(self, e: SliceExpr) -> Type: expected = make_optional_type(self.named_type('builtins.int')) for index in [e.begin_index, e.end_index, e.stride]: @@ -4001,3 +4030,22 @@ def has_bytes_component(typ: Type) -> bool: if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes': return True return False + + +def type_info_from_type(typ: Type) -> Optional[TypeInfo]: + """Gets the TypeInfo for a type, indirecting through things like type variables and tuples.""" + + if isinstance(typ, FunctionLike) and typ.is_type_obj(): + return typ.type_object() + if isinstance(typ, TypeType): + typ = typ.item + if isinstance(typ, TypeVarType): + typ = typ.upper_bound + if isinstance(typ, TupleType): + typ = tuple_fallback(typ) + if isinstance(typ, Instance): + return typ.type + + # A complicated type. Too tricky, give up. + # TODO: Do something more clever here. + return None diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 7474f7f5bd8d..50adb3973cb7 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -136,17 +136,17 @@ def _analyze_member_access(name: str, elif isinstance(typ, FunctionLike) and typ.is_type_obj(): return analyze_type_callable_member_access(name, typ, mx) elif isinstance(typ, TypeType): - return analyze_type_type_member_access(name, typ, mx) + return analyze_type_type_member_access(name, typ, mx, override_info) elif isinstance(typ, TupleType): # Actually look up from the fallback instance type. - return _analyze_member_access(name, tuple_fallback(typ), mx) + return _analyze_member_access(name, tuple_fallback(typ), mx, override_info) elif isinstance(typ, (TypedDictType, LiteralType, FunctionLike)): # Actually look up from the fallback instance type. - return _analyze_member_access(name, typ.fallback, mx) + return _analyze_member_access(name, typ.fallback, mx, override_info) elif isinstance(typ, NoneType): return analyze_none_member_access(name, typ, mx) elif isinstance(typ, TypeVarType): - return _analyze_member_access(name, typ.upper_bound, mx) + return _analyze_member_access(name, typ.upper_bound, mx, override_info) elif isinstance(typ, DeletedType): mx.msg.deleted_as_rvalue(typ, mx.context) return AnyType(TypeOfAny.from_error) @@ -238,7 +238,10 @@ def analyze_type_callable_member_access(name: str, assert False, 'Unexpected type {}'.format(repr(ret_type)) -def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext) -> Type: +def analyze_type_type_member_access(name: str, + typ: TypeType, + mx: MemberContext, + override_info: Optional[TypeInfo]) -> Type: # Similar to analyze_type_callable_attribute_access. item = None fallback = mx.builtin_type('builtins.type') @@ -248,7 +251,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext) item = typ.item elif isinstance(typ.item, AnyType): mx = mx.copy_modified(messages=ignore_messages) - return _analyze_member_access(name, fallback, mx) + return _analyze_member_access(name, fallback, mx, override_info) elif isinstance(typ.item, TypeVarType): if isinstance(typ.item.upper_bound, Instance): item = typ.item.upper_bound @@ -262,7 +265,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext) item = typ.item.item.type.metaclass_type if item and not mx.is_operator: # See comment above for why operators are skipped - result = analyze_class_attribute_access(item, name, mx) + result = analyze_class_attribute_access(item, name, mx, override_info) if result: if not (isinstance(result, AnyType) and item.type.fallback_to_any): return result @@ -271,7 +274,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext) mx = mx.copy_modified(messages=ignore_messages) if item is not None: fallback = item.type.metaclass_type or fallback - return _analyze_member_access(name, fallback, mx) + return _analyze_member_access(name, fallback, mx, override_info) def analyze_union_member_access(name: str, typ: UnionType, mx: MemberContext) -> Type: @@ -603,11 +606,16 @@ class A: def analyze_class_attribute_access(itype: Instance, name: str, - mx: MemberContext) -> Optional[Type]: + mx: MemberContext, + override_info: Optional[TypeInfo] = None) -> Optional[Type]: """original_type is the type of E in the expression E.var""" - node = itype.type.get(name) + info = itype.type + if override_info: + info = override_info + + node = info.get(name) if not node: - if itype.type.fallback_to_any: + if info.fallback_to_any: return AnyType(TypeOfAny.special_form) return None @@ -628,9 +636,9 @@ def analyze_class_attribute_access(itype: Instance, # An assignment to final attribute on class object is also always an error, # independently of types. if mx.is_lvalue and not mx.chk.get_final_context(): - check_final_member(name, itype.type, mx.msg, mx.context) + check_final_member(name, info, mx.msg, mx.context) - if itype.type.is_enum and not (mx.is_lvalue or is_decorated or is_method): + if info.is_enum and not (mx.is_lvalue or is_decorated or is_method): enum_literal = LiteralType(name, fallback=itype) return itype.copy_modified(last_known_value=enum_literal) @@ -691,7 +699,7 @@ def analyze_class_attribute_access(itype: Instance, if isinstance(node.node, TypeVarExpr): mx.msg.fail(message_registry.CANNOT_USE_TYPEVAR_AS_EXPRESSION.format( - itype.type.name(), name), mx.context) + info.name(), name), mx.context) return AnyType(TypeOfAny.from_error) if isinstance(node.node, TypeInfo): diff --git a/mypy/newsemanal/semanal.py b/mypy/newsemanal/semanal.py index edb9112bdc9c..9159e59c8924 100644 --- a/mypy/newsemanal/semanal.py +++ b/mypy/newsemanal/semanal.py @@ -3327,7 +3327,7 @@ def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None: expr.fullname = sym.fullname def visit_super_expr(self, expr: SuperExpr) -> None: - if not self.type: + if not self.type and not expr.call.args: self.fail('"super" used outside class', expr) return expr.info = self.type diff --git a/mypy/semanal.py b/mypy/semanal.py index 4f0dc73e579b..0806391f5e1e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2956,7 +2956,7 @@ def visit_name_expr(self, expr: NameExpr) -> None: expr.fullname = n.fullname def visit_super_expr(self, expr: SuperExpr) -> None: - if not self.type: + if not self.type and not expr.call.args: self.fail('"super" used outside class', expr) return expr.info = self.type diff --git a/test-data/unit/check-class-namedtuple.test b/test-data/unit/check-class-namedtuple.test index 38e60f45f50c..d9e4061e6aac 100644 --- a/test-data/unit/check-class-namedtuple.test +++ b/test-data/unit/check-class-namedtuple.test @@ -493,7 +493,7 @@ Y(y=1, x='1').method() class CallsBaseInit(X): def __init__(self, x: str) -> None: - super().__init__(x) + super().__init__(x) # E: Too many arguments for "__init__" of "object" [case testNewNamedTupleWithMethods] from typing import NamedTuple diff --git a/test-data/unit/check-super.test b/test-data/unit/check-super.test index e8c1a2721e58..54ee5841fe10 100644 --- a/test-data/unit/check-super.test +++ b/test-data/unit/check-super.test @@ -121,6 +121,7 @@ class C(B): def h(self, x) -> None: reveal_type(super(x, x).f) # N: Revealed type is 'def ()' reveal_type(super(C, x).f) # N: Revealed type is 'def ()' + reveal_type(super(C, type(x)).f) # N: Revealed type is 'def (self: __main__.B)' [case testSuperInUnannotatedMethod] class C: @@ -204,6 +205,53 @@ class C(B, Generic[T, S]): class D(C): pass +[case testSuperInClassMethod] +from typing import Union + +class A: + def f(self, i: int) -> None: pass + +class B(A): + def f(self, i: Union[int, str]) -> None: pass + + @classmethod + def g(cls, i: int) -> None: + super().f(B(), i) + super(B, cls).f(cls(), i) + super(B, B()).f(i) + + super().f(B(), '') # E: Argument 2 to "f" of "A" has incompatible type "str"; expected "int" + super(B, cls).f(cls(), '') # E: Argument 2 to "f" of "A" has incompatible type "str"; expected "int" + super(B, B()).f('') # E: Argument 1 to "f" of "A" has incompatible type "str"; expected "int" +[builtins fixtures/classmethod.pyi] + +[case testSuperWithUnrelatedTypes] +from typing import Union + +class A: + def f(self, s: str) -> None: pass + +class B(A): + def f(self, i: Union[int, str]) -> None: pass + +class C: + def g(self, b: B) -> None: + super(B, b).f('42') + super(B, b).f(42) # E: Argument 1 to "f" of "A" has incompatible type "int"; expected "str" + +[case testSuperOutsideClass] +from typing import Union + +class A: + def f(self, s: str) -> None: pass + +class B(A): + def f(self, i: Union[int, str]) -> None: pass + +def g(b: B) -> None: + super(B, b).f('42') + super(B, b).f(42) # E: Argument 1 to "f" of "A" has incompatible type "int"; expected "str" + -- Invalid uses of super() -- -----------------------