diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index d545b39e7f19..37bc031f95dc 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -88,23 +88,6 @@ def transform(self) -> None: args=[attr.to_argument(info) for attr in attributes if attr.is_in_init], return_type=NoneTyp(), ) - for stmt in self._ctx.cls.defs.body: - # Fix up the types of classmethods since, by default, - # they will be based on the parent class' init. - if isinstance(stmt, Decorator) and stmt.func.is_class: - func_type = stmt.func.type - if isinstance(func_type, CallableType): - func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info) - if isinstance(stmt, OverloadedFuncDef) and stmt.is_class: - func_type = stmt.type - if isinstance(func_type, Overloaded): - class_type = ctx.api.class_type(ctx.cls.info) - for item in func_type.items(): - item.arg_types[0] = class_type - if stmt.impl is not None: - assert isinstance(stmt.impl, Decorator) - if isinstance(stmt.impl.func.type, CallableType): - stmt.impl.func.type.arg_types[0] = class_type # Add an eq method, but only if the class doesn't already have one. if decorator_arguments['eq'] and info.get('__eq__') is None: diff --git a/mypy/semanal.py b/mypy/semanal.py index b7d9942f4243..049ace77024e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -214,6 +214,10 @@ class SemanticAnalyzerPass2(NodeVisitor[None], # postpone_nested_functions_stack[-1] == FUNCTION_FIRST_PHASE_POSTPONE_SECOND. postponed_functions_stack = None # type: List[List[Node]] + # Classmethod definitions that couldn't be prepared yet + # until the class analysis is complete + postponed_classmethods_stack = None # type: List[List[FuncDef]] + loop_depth = 0 # Depth of breakable loops cur_mod_id = '' # Current module id (or None) (phase 2) is_stub_file = False # Are we analyzing a stub file? @@ -246,6 +250,7 @@ def __init__(self, self.missing_modules = missing_modules self.postpone_nested_functions_stack = [FUNCTION_BOTH_PHASES] self.postponed_functions_stack = [] + self.postponed_classmethods_stack = [] self.all_exports = set() # type: Set[str] self.plugin = plugin # If True, process function definitions. If False, don't. This is used @@ -456,6 +461,14 @@ def _visit_func_def(self, defn: FuncDef) -> None: assert ret_type is not None, "Internal error: typing.Coroutine not found" defn.type = defn.type.copy_modified(ret_type=ret_type) + def prepare_postponed_classmethods_signature(self) -> None: + for func in self.postponed_classmethods_stack[-1]: + functype = func.type + assert isinstance(functype, CallableType) + assert self.type, "Classmethod singatures preparation outside of a class" + leading_type = self.class_type(self.type) + func.type = replace_implicit_first_type(functype, leading_type) + def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None: """Check basic signature validity and tweak annotation of self/cls argument.""" # Only non-static methods are special. @@ -467,10 +480,10 @@ def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None: self_type = functype.arg_types[0] if isinstance(self_type, AnyType): if func.is_class or func.name() in ('__new__', '__init_subclass__'): - leading_type = self.class_type(info) + self.postponed_classmethods_stack[-1].append(func) else: leading_type = fill_typevars(info) - func.type = replace_implicit_first_type(functype, leading_type) + func.type = replace_implicit_first_type(functype, leading_type) def set_original_def(self, previous: Optional[Node], new: FuncDef) -> bool: """If 'new' conditionally redefine 'previous', set 'previous' as original @@ -834,6 +847,7 @@ def enter_class(self, info: TypeInfo) -> None: self.locals.append(None) # Add class scope self.block_depth.append(-1) # The class body increments this to 0 self.postpone_nested_functions_stack.append(FUNCTION_BOTH_PHASES) + self.postponed_classmethods_stack.append([]) self.type = info def leave_class(self) -> None: @@ -841,6 +855,8 @@ def leave_class(self) -> None: self.postpone_nested_functions_stack.pop() self.block_depth.pop() self.locals.pop() + self.prepare_postponed_classmethods_signature() + self.postponed_classmethods_stack.pop() self.type = self.type_stack.pop() def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 5216150dfa8b..774964bbafe3 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -616,6 +616,62 @@ main:3: error: Incompatible types in assignment (expression has type "B", variab main:4: error: Incompatible types in assignment (expression has type "A", variable has type "D") main:5: error: Incompatible types in assignment (expression has type "D2", variable has type "D") +[case testClassmethodBeforeInit] +class C: + @classmethod + def create(cls, arg: int) -> 'C': + reveal_type(cls) # E: Revealed type is 'def (arg: builtins.int) -> __main__.C' + reveal_type(cls(arg)) # E: Revealed type is '__main__.C' + cls() # E: Too few arguments for "C" + return cls(arg) + + def __init__(self, arg: 'int') -> None: + self._arg: int = arg + +reveal_type(C.create) # E: Revealed type is 'def (arg: builtins.int) -> __main__.C' +[builtins fixtures/classmethod.pyi] + +[case testClassmethodBeforeInitInNestedClass] +class A: + @classmethod + def create(cls, arg: int) -> 'A': + reveal_type(cls) # E: Revealed type is 'def (arg: builtins.int) -> __main__.A' + reveal_type(cls(arg)) # E: Revealed type is '__main__.A' + cls() # E: Too few arguments for "A" + return cls(arg) + + class B: + @classmethod + def create(cls, arg: str) -> 'A.B': + reveal_type(cls) # E: Revealed type is 'def (arg: builtins.str) -> __main__.A.B' + reveal_type(cls(arg)) # E: Revealed type is '__main__.A.B' + cls() # E: Too few arguments for "B" + return cls(arg) + + def __init__(self, arg: 'str') -> None: + self._arg: str = arg + + def __init__(self, arg: 'int') -> None: + self._arg: int = arg + +reveal_type(A.create) # E: Revealed type is 'def (arg: builtins.int) -> __main__.A' +reveal_type(A.B.create) # E: Revealed type is 'def (arg: builtins.str) -> __main__.A.B' +[builtins fixtures/classmethod.pyi] + +[case testClassmethodBeforeInitInGenericClass] +from typing import Generic, TypeVar +T = TypeVar('T') +class C(Generic[T]): + @classmethod + def create(cls, arg: T) -> 'C': + cls() # E: Too few arguments for "C" + return cls(arg, arg) + + def __init__(self, arg1: T, arg2: T) -> None: + self._arg1 = arg1 + self._arg2 = arg2 +[builtins fixtures/classmethod.pyi] + -- Attribute access in class body -- ------------------------------ diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index ef26f7f91484..c3ed0ca5370d 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -523,6 +523,21 @@ class A: def __init__(self, b: 'B') -> None: pass class B: pass +[case testOverloadedInitAndClassMethod] +from foo import * +[file foo.pyi] +from typing import overload +class A: + @classmethod + def from_int(cls, x: int) -> 'A': + reveal_type(cls) # E: Revealed type is 'Any' + return cls(x) + @overload + def __init__(self, a: int) -> None: pass + @overload + def __init__(self, b: str) -> None: pass +[builtins fixtures/classmethod.pyi] + [case testIntersectionTypeCompatibility] from foo import * [file foo.pyi] diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index 080b41693e2f..168700f8c6dc 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -29,6 +29,19 @@ class A: -> m.f -> , m.A, m.f +[case testCallClassMethod] +def f() -> None: + A.g() +class A: + @classmethod + def g(cls) -> None: pass +[out] + -> m.f + -> m.f + -> m, m.f + -> m.A, m.f +[builtins fixtures/classmethod.pyi] + [case testAccessAttribute] def f(a: A) -> None: a.x diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 43233f6d5764..87581b8a7bb9 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -2143,6 +2143,51 @@ main:3: note: Possible overload variants: main:3: note: def foo(cls: Wrapper, x: int) -> int main:3: note: def foo(cls: Wrapper, x: str) -> str +[case testInitMovesAroundClassmethod] +[file m.py] +class C: + def __init__(self, x: int) -> None: + self._x = x + @classmethod + def create(cls, x: int) -> 'C': + return cls(x) +[file m.py.2] +class C: + @classmethod + def create(cls, x: int) -> 'C': + return cls(x) + def __init__(self, x: int) -> None: + self._x = x +[file m.py.3] +class C: + def __init__(self, x: int) -> None: + self._x = x + @classmethod + def create(cls, x: int) -> 'C': + return cls(x) +[builtins fixtures/classmethod.pyi] +[out] +== +== + +[case testInitDisappearsClassmethodStays] +[file m.py] +class C: + def __init__(self, x: int) -> None: + self._x = x + @classmethod + def create(cls, x: int) -> 'C': + return cls(x) +[file m.py.2] +class C: + @classmethod + def create(cls, x: int) -> 'C': + return cls(x) +[builtins fixtures/classmethod.pyi] +[out] +== +m.py:4: error: Too many arguments for "C" + [case testRefreshGenericClass] from typing import TypeVar, Generic from a import A