Skip to content

Commit d30137c

Browse files
committed
Support forward references to decorated functions
This handles a variety of common cases of forward references to decorated functions by effectively performing limited type inference during semantic analysis. Also fixes forward references to properties. Closes #613. Closes #857.
1 parent cabf101 commit d30137c

File tree

4 files changed

+293
-17
lines changed

4 files changed

+293
-17
lines changed

mypy/build.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,9 @@ def __init__(self, data_dir: str,
339339
self.reports = reports
340340
self.semantic_analyzer = SemanticAnalyzer(lib_path, self.errors,
341341
pyversion=pyversion)
342-
self.semantic_analyzer_pass3 = ThirdPass(self.errors)
343-
self.type_checker = TypeChecker(self.errors,
344-
self.semantic_analyzer.modules,
345-
self.pyversion)
342+
modules = self.semantic_analyzer.modules
343+
self.semantic_analyzer_pass3 = ThirdPass(modules, self.errors)
344+
self.type_checker = TypeChecker(self.errors, modules, self.pyversion)
346345
self.states = [] # type: List[State]
347346
self.module_files = {} # type: Dict[str, str]
348347
self.module_deps = {} # type: Dict[Tuple[str, str], bool]

mypy/semanal.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,7 @@ def visit_decorator(self, dec: Decorator) -> None:
13571357
removed = [] # type: List[int]
13581358
no_type_check = False
13591359
for i, d in enumerate(dec.decorators):
1360+
# A bunch of decorators are special cased here.
13601361
if refers_to_fullname(d, 'abc.abstractmethod'):
13611362
removed.append(i)
13621363
dec.func.is_abstract = True
@@ -1395,14 +1396,10 @@ def visit_decorator(self, dec: Decorator) -> None:
13951396
dec.var.is_initialized_in_class = True
13961397
self.add_symbol(dec.var.name(), SymbolTableNode(MDEF, dec),
13971398
dec)
1398-
if dec.decorators and dec.var.is_property:
1399-
self.fail('Decorated property not supported', dec)
14001399
if not no_type_check:
14011400
dec.func.accept(self)
1402-
if not dec.decorators and not dec.var.is_property:
1403-
# No non-special decorators left. We can trivially infer the type
1404-
# of the function here.
1405-
dec.var.type = dec.func.type
1401+
if dec.decorators and dec.var.is_property:
1402+
self.fail('Decorated property not supported', dec)
14061403

14071404
def check_decorated_function_is_method(self, decorator: str,
14081405
context: Context) -> None:
@@ -2162,10 +2159,12 @@ def visit_try_stmt(self, s: TryStmt) -> None:
21622159
class ThirdPass(TraverserVisitor[None]):
21632160
"""The third and final pass of semantic analysis.
21642161
2165-
Check type argument counts and values of generic types.
2162+
Check type argument counts and values of generic types, and perform some
2163+
straightforward type inference.
21662164
"""
21672165

2168-
def __init__(self, errors: Errors) -> None:
2166+
def __init__(self, modules: Dict[str, MypyFile], errors: Errors) -> None:
2167+
self.modules = modules
21692168
self.errors = errors
21702169

21712170
def visit_file(self, file_node: MypyFile, fnam: str) -> None:
@@ -2183,6 +2182,37 @@ def visit_class_def(self, tdef: ClassDef) -> None:
21832182
self.analyze(type)
21842183
super().visit_class_def(tdef)
21852184

2185+
def visit_decorator(self, dec: Decorator) -> None:
2186+
"""Try to infer the type of the decorated function.
2187+
2188+
This helps us resolve forward references to decorated
2189+
functions during type checking.
2190+
"""
2191+
super().visit_decorator(dec)
2192+
if dec.var.is_property:
2193+
if dec.func.type is None:
2194+
dec.var.type = AnyType()
2195+
elif isinstance(dec.func.type, CallableType):
2196+
dec.var.type = dec.func.type.ret_type
2197+
return
2198+
decorator_preserves_type = True
2199+
for expr in dec.decorators:
2200+
preserve_type = False
2201+
if isinstance(expr, RefExpr) and isinstance(expr.node, FuncDef):
2202+
if is_identity_signature(expr.node.type):
2203+
preserve_type = True
2204+
if not preserve_type:
2205+
decorator_preserves_type = False
2206+
break
2207+
if decorator_preserves_type:
2208+
# No non-special decorators left. We can trivially infer the type
2209+
# of the function here.
2210+
dec.var.type = function_type(dec.func, self.builtin_type('function'))
2211+
if dec.decorators and returns_any_if_called(dec.decorators[0]):
2212+
# The outermost decorator will return Any so we know the type of the
2213+
# decorated function.
2214+
dec.var.type = AnyType()
2215+
21862216
def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
21872217
self.analyze(s.type)
21882218
super().visit_assignment_stmt(s)
@@ -2196,6 +2226,8 @@ def visit_type_application(self, e: TypeApplication) -> None:
21962226
self.analyze(type)
21972227
super().visit_type_application(e)
21982228

2229+
# Helpers
2230+
21992231
def analyze(self, type: Type) -> None:
22002232
if type:
22012233
analyzer = TypeAnalyserPass3(self.fail)
@@ -2204,6 +2236,12 @@ def analyze(self, type: Type) -> None:
22042236
def fail(self, msg: str, ctx: Context) -> None:
22052237
self.errors.report(ctx.get_line(), msg)
22062238

2239+
def builtin_type(self, name: str, args: List[Type] = None) -> Instance:
2240+
names = self.modules['builtins']
2241+
sym = names.names[name]
2242+
assert isinstance(sym.node, TypeInfo)
2243+
return Instance(cast(TypeInfo, sym.node), args or [])
2244+
22072245

22082246
def self_type(typ: TypeInfo) -> Union[Instance, TupleType]:
22092247
"""For a non-generic type, return instance type representing the type.
@@ -2356,3 +2394,34 @@ def visit_import_from(self, node: ImportFrom) -> None:
23562394

23572395
def visit_import_all(self, node: ImportAll) -> None:
23582396
node.is_unreachable = True
2397+
2398+
2399+
def is_identity_signature(sig: Type) -> bool:
2400+
"""Is type a callable of form T -> T (where T is a type variable)?"""
2401+
if isinstance(sig, CallableType) and sig.arg_kinds == [ARG_POS]:
2402+
if isinstance(sig.arg_types[0], TypeVarType) and isinstance(sig.ret_type, TypeVarType):
2403+
return sig.arg_types[0].id == sig.ret_type.id
2404+
return False
2405+
2406+
2407+
def returns_any_if_called(expr: Node) -> bool:
2408+
"""Return True if we can predict that expr will return Any if called.
2409+
2410+
This only uses information available during semantic analysis so this
2411+
will sometimes return False because of insufficient information (as
2412+
type inference hasn't run yet).
2413+
"""
2414+
if isinstance(expr, RefExpr):
2415+
if isinstance(expr.node, FuncDef):
2416+
typ = expr.node.type
2417+
if typ is None:
2418+
# No signature -> default to Any.
2419+
return True
2420+
# Explicit Any return?
2421+
return isinstance(typ, CallableType) and isinstance(typ.ret_type, AnyType)
2422+
elif isinstance(expr.node, Var):
2423+
typ = expr.node.type
2424+
return typ is None or isinstance(typ, AnyType)
2425+
elif isinstance(expr, CallExpr):
2426+
return returns_any_if_called(expr.callee)
2427+
return False

mypy/test/data/check-functions.test

Lines changed: 209 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,6 @@ def dec2(f: Callable[[Any, Any], None]) -> Callable[[Any], None]: pass
559559
@dec2
560560
def f(x, y): pass
561561

562-
563562
[case testNoTypeCheckDecoratorOnMethod1]
564563
from typing import no_type_check
565564

@@ -606,6 +605,214 @@ def foo(x: {1:2}) -> [1]:
606605
x = y
607606

608607

608+
-- Forward references to decorated functions
609+
-- -----------------------------------------
610+
611+
612+
[case testForwardReferenceToDynamicallyTypedDecorator]
613+
def f(self) -> None:
614+
g()
615+
g(1)
616+
617+
def dec(f):
618+
return f
619+
620+
@dec
621+
def g():
622+
pass
623+
624+
[case testForwardReferenceToDecoratorWithAnyReturn]
625+
from typing import Any
626+
627+
def f(self) -> None:
628+
g()
629+
g(1)
630+
631+
def dec(f) -> Any:
632+
return f
633+
634+
@dec
635+
def g():
636+
pass
637+
638+
[case testForwardReferenceToDecoratorWithIdentityMapping]
639+
from typing import TypeVar
640+
641+
def f(self) -> None:
642+
g()
643+
g(1) # E: Too many arguments for "g"
644+
h(1).x # E: "str" has no attribute "x"
645+
h('') # E: Argument 1 to "h" has incompatible type "str"; expected "int"
646+
647+
T = TypeVar('T')
648+
def dec(f: T) -> T:
649+
return f
650+
651+
@dec
652+
def g(): pass
653+
@dec
654+
def h(x: int) -> str: pass
655+
[out]
656+
main: note: In function "f":
657+
658+
[case testForwardReferenceToDynamicallyTypedDecoratedMethod]
659+
def f(self) -> None:
660+
A().f(1).y
661+
A().f()
662+
663+
class A:
664+
@dec
665+
def f(self, x): pass
666+
667+
def dec(f): return f
668+
[builtins fixtures/staticmethod.py]
669+
670+
[case testForwardReferenceToStaticallyTypedDecoratedMethod]
671+
from typing import TypeVar
672+
673+
def f(self) -> None:
674+
A().f(1).y # E: "str" has no attribute "y"
675+
A().f('') # E: Argument 1 to "f" of "A" has incompatible type "str"; expected "int"
676+
677+
class A:
678+
@dec
679+
def f(self, a: int) -> str: return ''
680+
681+
T = TypeVar('T')
682+
def dec(f: T) -> T: return f
683+
[builtins fixtures/staticmethod.py]
684+
[out]
685+
main: note: In function "f":
686+
687+
[case testForwardReferenceToDynamicallyTypedProperty]
688+
def f(self) -> None:
689+
A().x.y
690+
691+
class A:
692+
@property
693+
def x(self): pass
694+
[builtins fixtures/property.py]
695+
696+
[case testForwardReferenceToStaticallyTypedProperty]
697+
def f(self) -> None:
698+
A().x.y # E: "int" has no attribute "y"
699+
700+
class A:
701+
@property
702+
def x(self) -> int: return 1
703+
[builtins fixtures/property.py]
704+
[out]
705+
main: note: In function "f":
706+
707+
[case testForwardReferenceToDynamicallyTypedStaticMethod]
708+
def f(self) -> None:
709+
A.x(1).y
710+
A.x() # E: Too few arguments for "x"
711+
712+
class A:
713+
@staticmethod
714+
def x(x): pass
715+
[builtins fixtures/staticmethod.py]
716+
[out]
717+
main: note: In function "f":
718+
719+
[case testForwardReferenceToStaticallyTypedStaticMethod]
720+
def f(self) -> None:
721+
A.x(1).y # E: "str" has no attribute "y"
722+
A.x('') # E: Argument 1 to "x" of "A" has incompatible type "str"; expected "int"
723+
724+
class A:
725+
@staticmethod
726+
def x(a: int) -> str: return ''
727+
[builtins fixtures/staticmethod.py]
728+
[out]
729+
main: note: In function "f":
730+
731+
[case testForwardReferenceToDynamicallyTypedClassMethod]
732+
def f(self) -> None:
733+
A.x(1).y
734+
A.x() # E: Too few arguments for "x"
735+
736+
class A:
737+
@classmethod
738+
def x(cls, a): pass
739+
[builtins fixtures/classmethod.py]
740+
[out]
741+
main: note: In function "f":
742+
743+
[case testForwardReferenceToStaticallyTypedClassMethod]
744+
def f(self) -> None:
745+
A.x(1).y # E: "str" has no attribute "y"
746+
A.x('') # E: Argument 1 to "x" of "A" has incompatible type "str"; expected "int"
747+
748+
class A:
749+
@classmethod
750+
def x(cls, x: int) -> str: return ''
751+
[builtins fixtures/classmethod.py]
752+
[out]
753+
main: note: In function "f":
754+
755+
[case testForwardReferenceToDecoratedFunctionUsingMemberExpr]
756+
import m
757+
758+
def f(self) -> None:
759+
g(1).x # E: "str" has no attribute "x"
760+
761+
@m.dec
762+
def g(x: int) -> str: pass
763+
[file m.py]
764+
from typing import TypeVar
765+
T = TypeVar('T')
766+
def dec(f: T) -> T:
767+
return f
768+
[out]
769+
main: note: In function "f":
770+
771+
[case testForwardReferenceToFunctionWithMultipleDecorators]
772+
def f(self) -> None:
773+
g()
774+
g(1)
775+
776+
def dec(f):
777+
return f
778+
779+
@dec
780+
@dec2
781+
def g():
782+
pass
783+
784+
def dec2(f):
785+
return f
786+
787+
[case testForwardReferenceToDynamicallyTypedDecoratedStaticMethod]
788+
def f(self) -> None:
789+
A().f(1).y
790+
A().f()
791+
A().g(1).y
792+
A().g()
793+
794+
class A:
795+
@dec
796+
@staticmethod
797+
def f(self, x): pass
798+
@staticmethod
799+
@dec
800+
def g(self, x): pass
801+
802+
def dec(f): return f
803+
[builtins fixtures/staticmethod.py]
804+
805+
[case testForwardRefereceToDecoratedFunctionWithCallExpressionDecorator]
806+
def f(self) -> None:
807+
g()
808+
g(1)
809+
810+
@dec(1)
811+
def g(): pass
812+
813+
def dec(f): pass
814+
815+
609816
-- Conditional function definition
610817
-- -------------------------------
611818

@@ -874,7 +1081,7 @@ def f(x, y):
8741081
v = realtypes(x, y) # E: Argument 2 to "realtypes" has incompatible type "object"; expected "int"
8751082
return v # E: Incompatible return value type: expected builtins.int, got builtins.str
8761083
[out]
877-
main: note: In function "f":
1084+
main: note: In function "f":
8781085

8791086
[case testCallDocstringFunction-skip]
8801087
from typing import List

0 commit comments

Comments
 (0)