From 99436eefd6e6edca02d6b3f10fa6087265e3f3a3 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 25 Jun 2021 12:27:36 +0200 Subject: [PATCH 01/20] Add support for conditionally defined overloads --- mypy/fastparse.py | 53 ++++++++- test-data/unit/check-overloading.test | 157 ++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 1 deletion(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index a0d0ec8e34b0..56723ee295f0 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -39,7 +39,7 @@ from mypy import message_registry, errorcodes as codes from mypy.errors import Errors from mypy.options import Options -from mypy.reachability import mark_block_unreachable +from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable try: # pull this into a final variable to make mypyc be quiet about the @@ -447,12 +447,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret: List[Statement] = [] current_overload: List[OverloadPart] = [] current_overload_name: Optional[str] = None + last_if_stmt: Optional[IfStmt] = None + last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None for stmt in stmts: if (current_overload_name is not None and isinstance(stmt, (Decorator, FuncDef)) and stmt.name == current_overload_name): + if last_if_overload is not None: + if isinstance(last_if_overload, OverloadedFuncDef): + current_overload.extend(last_if_overload.items) + else: + current_overload.append(last_if_overload) + last_if_stmt, last_if_overload = None, None current_overload.append(stmt) + elif ( + current_overload_name is not None + and isinstance(stmt, IfStmt) + and len(stmt.body[0].body) == 1 + and isinstance( + stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and stmt.body[0].body[0].name == current_overload_name + ): + # IfStmt only contains stmts relevant to current_overload. + # Check if stmts are reachable and add them to current_overload, + # otherwise skip IfStmt to allow subsequent overload + # or function definitions. + infer_reachability_of_if_statement(stmt, self.options) + if stmt.body[0].is_unreachable is True: + continue + if last_if_overload is not None: + if isinstance(last_if_overload, OverloadedFuncDef): + current_overload.extend(last_if_overload.items) + else: + current_overload.append(last_if_overload) + last_if_stmt, last_if_overload = None, None + last_if_overload = None + if isinstance(stmt.body[0].body[0], OverloadedFuncDef): + current_overload.extend(stmt.body[0].body[0].items) + else: + current_overload.append(stmt.body[0].body[0]) else: + if last_if_stmt is not None: + ret.append(last_if_stmt) + last_if_stmt, last_if_overload = None, None + if len(current_overload) == 1: ret.append(current_overload[0]) elif len(current_overload) > 1: @@ -466,6 +504,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if isinstance(stmt, Decorator) and not unnamed_function(stmt.name): current_overload = [stmt] current_overload_name = stmt.name + elif ( + isinstance(stmt, IfStmt) + and len(stmt.body[0].body) == 1 + and isinstance( + stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) + and infer_reachability_of_if_statement( + stmt, self.options + ) is None # type: ignore[func-returns-value] + and stmt.body[0].is_unreachable is False + ): + current_overload_name = stmt.body[0].body[0].name + last_if_stmt = stmt + last_if_overload = stmt.body[0].body[0] else: current_overload = [] current_overload_name = None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index bf7acdc1cd51..d6831c03a69a 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5339,3 +5339,160 @@ def register(cls: Any) -> Any: return None x = register(Foo) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/dict.pyi] + +[case testOverloadIfBasic] +# flags: --always-true True +from typing import overload, Any + +class A: ... +class B: ... + +@overload +def f1(g: int) -> A: ... +if True: + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +@overload +def f2(g: bytes) -> A: ... +if not True: + @overload + def f2(g: str) -> B: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f2(g: int) -> A \ + # N: def f2(g: bytes) -> A \ + # N: Revealed type is "Any" + +[case testOverloadIfSysVersion] +# flags: --python-version 3.9 +from typing import overload, Any +import sys + +class A: ... +class B: ... + +@overload +def f1(g: int) -> A: ... +if sys.version_info >= (3, 9): + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +@overload +def f2(g: bytes) -> A: ... +if sys.version_info >= (3, 10): + @overload + def f2(g: str) -> B: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f2(g: int) -> A \ + # N: def f2(g: bytes) -> A \ + # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testOverloadIfMatching] +from typing import overload, Any + +class A: ... +class B: ... +class C: ... + +@overload +def f1(g: int) -> A: ... +if True: + # Some comment + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +if True: + @overload + def f2(g: bytes) -> B: ... + @overload + def f2(g: str) -> C: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # N: Revealed type is "__main__.C" + +@overload +def f3(g: int) -> A: ... +@overload +def f3(g: str) -> B: ... +if True: + def f3(g: Any) -> Any: ... +reveal_type(f3(42)) # N: Revealed type is "__main__.A" +reveal_type(f3("Hello")) # N: Revealed type is "__main__.B" + +if True: + @overload + def f4(g: int) -> A: ... +@overload +def f4(g: str) -> B: ... +def f4(g: Any) -> Any: ... +reveal_type(f4(42)) # N: Revealed type is "__main__.A" +reveal_type(f4("Hello")) # N: Revealed type is "__main__.B" + +if True: + # Some comment + @overload + def f5(g: int) -> A: ... + @overload + def f5(g: str) -> B: ... +def f5(g: Any) -> Any: ... +reveal_type(f5(42)) # N: Revealed type is "__main__.A" +reveal_type(f5("Hello")) # N: Revealed type is "__main__.B" + +[case testOverloadIfNotMatching] +from typing import overload, Any + +class A: ... +class B: ... +class C: ... + +@overload # E: An overloaded function outside a stub file must have an implementation +def f1(g: int) -> A: ... +@overload +def f1(g: bytes) -> B: ... +if True: + @overload # E: Name "f1" already defined on line 7 \ + # E: Single overload definition, multiple required + def f1(g: str) -> C: ... + pass # Some other action +def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7 +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f1(g: int) -> A \ + # N: def f1(g: bytes) -> B \ + # N: Revealed type is "Any" + +if True: + pass # Some other action + @overload # E: Single overload definition, multiple required + def f2(g: int) -> A: ... +@overload # E: Name "f2" already defined on line 21 +def f2(g: bytes) -> B: ... +@overload +def f2(g: str) -> C: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \ + # E: Argument 1 to "f2" has incompatible type "str"; expected "int" From 7e7502b35b40ff504f58d0f59875754455bb320b Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 25 Jun 2021 23:52:17 +0200 Subject: [PATCH 02/20] Bugfix --- mypy/fastparse.py | 7 ++++-- test-data/unit/check-overloading.test | 31 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 56723ee295f0..efdbb29cd787 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -509,11 +509,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: and len(stmt.body[0].body) == 1 and isinstance( stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) - and infer_reachability_of_if_statement( + and infer_reachability_of_if_statement( # type: ignore[func-returns-value] stmt, self.options - ) is None # type: ignore[func-returns-value] + ) is None and stmt.body[0].is_unreachable is False ): + current_overload = [] current_overload_name = stmt.body[0].body[0].name last_if_stmt = stmt last_if_overload = stmt.body[0].body[0] @@ -526,6 +527,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret.append(current_overload[0]) elif len(current_overload) > 1: ret.append(OverloadedFuncDef(current_overload)) + elif last_if_stmt is not None: + ret.append(last_if_stmt) return ret def in_method_scope(self) -> bool: diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index d6831c03a69a..8cf7833056ce 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5496,3 +5496,34 @@ def f2(g: Any) -> Any: ... reveal_type(f2(42)) # N: Revealed type is "__main__.A" reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \ # E: Argument 1 to "f2" has incompatible type "str"; expected "int" + +[case testOverloadIfOldStyle] +# flags: --always-false var_false --always-true var_true +from typing import overload, Any + +class A: ... +class B: ... + +var_true = True +var_false = False + +if var_false: + @overload + def f1(g: int) -> A: ... + @overload + def f1(g: str) -> B: ... + def f1(g: Any) -> Any: ... +elif var_true: + @overload + def f1(g: int) -> A: ... + @overload + def f1(g: str) -> B: ... + def f1(g: Any) -> Any: ... +else: + @overload + def f1(g: int) -> A: ... + @overload + def f1(g: str) -> B: ... + def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" From 3effa15cb9897c4bef9bc1024436dfb0d7a09833 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 14 Dec 2021 22:28:42 +0100 Subject: [PATCH 03/20] Redo logic to support elif + else --- mypy/fastparse.py | 99 ++++++++++++++++++++++----- test-data/unit/check-overloading.test | 8 ++- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index efdbb29cd787..35b8e7a599db 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -450,10 +450,24 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: last_if_stmt: Optional[IfStmt] = None last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None for stmt in stmts: + if_overload_name: Optional[str] = None + if_block_with_overload: Optional[Block] = None + if ( + isinstance(stmt, IfStmt) + and len(stmt.body[0].body) == 1 + and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + ): + # Check IfStmt block to determine if function overloads can be merged + if_overload_name = self._check_ifstmt_for_overloads(stmt) + if if_overload_name is not None: + if_block_with_overload = self._get_executable_if_block_with_overloads(stmt) + if (current_overload_name is not None and isinstance(stmt, (Decorator, FuncDef)) and stmt.name == current_overload_name): if last_if_overload is not None: + # Last stmt was an IfStmt with same overload name + # Add overloads to current_overload if isinstance(last_if_overload, OverloadedFuncDef): current_overload.extend(last_if_overload.items) else: @@ -463,29 +477,26 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: elif ( current_overload_name is not None and isinstance(stmt, IfStmt) - and len(stmt.body[0].body) == 1 - and isinstance( - stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) - and stmt.body[0].body[0].name == current_overload_name + and if_overload_name == current_overload_name ): # IfStmt only contains stmts relevant to current_overload. # Check if stmts are reachable and add them to current_overload, # otherwise skip IfStmt to allow subsequent overload # or function definitions. - infer_reachability_of_if_statement(stmt, self.options) - if stmt.body[0].is_unreachable is True: + if if_block_with_overload is None: continue if last_if_overload is not None: + # Last stmt was an IfStmt with same overload name + # Add overloads to current_overload if isinstance(last_if_overload, OverloadedFuncDef): current_overload.extend(last_if_overload.items) else: current_overload.append(last_if_overload) last_if_stmt, last_if_overload = None, None - last_if_overload = None - if isinstance(stmt.body[0].body[0], OverloadedFuncDef): - current_overload.extend(stmt.body[0].body[0].items) + if isinstance(if_block_with_overload.body[0], OverloadedFuncDef): + current_overload.extend(if_block_with_overload.body[0].items) else: - current_overload.append(stmt.body[0].body[0]) + current_overload.append(if_block_with_overload.body[0]) else: if last_if_stmt is not None: ret.append(last_if_stmt) @@ -506,18 +517,13 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload_name = stmt.name elif ( isinstance(stmt, IfStmt) - and len(stmt.body[0].body) == 1 - and isinstance( - stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) - and infer_reachability_of_if_statement( # type: ignore[func-returns-value] - stmt, self.options - ) is None - and stmt.body[0].is_unreachable is False + and if_overload_name is not None + and if_block_with_overload is not None ): current_overload = [] - current_overload_name = stmt.body[0].body[0].name + current_overload_name = if_overload_name last_if_stmt = stmt - last_if_overload = stmt.body[0].body[0] + last_if_overload = if_block_with_overload.body[0] else: current_overload = [] current_overload_name = None @@ -531,6 +537,61 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret.append(last_if_stmt) return ret + def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: + """Check if IfStmt contains only overloads with the same name. + Return overload_name if found, None otherwise. + """ + # Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef. + # Multiple overloads have already been merged as OverloadedFuncDef. + if not ( + len(stmt.body[0].body) == 1 + or isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + ): + return None + + overload_name = stmt.body[0].body[0].name + if stmt.else_body is None: + return overload_name + + if isinstance(stmt.else_body, Block) and len(stmt.else_body.body) == 1: + # For elif: else_body contains an IfStmt itself -> do a recursive check. + if ( + isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and stmt.else_body.body[0].name == overload_name + ): + return overload_name + if ( + isinstance(stmt.else_body.body[0], IfStmt) + and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name + ): + return overload_name + + return None + + def _get_executable_if_block_with_overloads(self, stmt: IfStmt) -> Optional[Block]: + """Return block from IfStmt that will get executed. + + Only returns block if sure that alternative blocks are unreachable. + """ + infer_reachability_of_if_statement(stmt, self.options) + if ( + stmt.else_body is None + or stmt.body[0].is_unreachable is False + and stmt.else_body.is_unreachable is False + ): + # The truth value is unknown, thus not conclusive + return None + if stmt.else_body.is_unreachable is True: + # else_body will be set unreachable if condition is always True + return stmt.body[0] + if stmt.body[0].is_unreachable is True: + # body will be set unreachable if condition is always False + # else_body can contain an IfStmt itself (for elif) -> do a recursive check + if isinstance(stmt.else_body.body[0], IfStmt): + return self._get_executable_if_block_with_overloads(stmt.else_body.body[0]) + return stmt.else_body + return None + def in_method_scope(self) -> bool: return self.class_and_function_stack[-2:] == ['C', 'F'] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index ea41cc940745..365f6c9a2a3a 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5488,6 +5488,7 @@ reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type [builtins fixtures/tuple.pyi] [case testOverloadIfMatching] +# flags: --always-true True from typing import overload, Any class A: ... @@ -5544,6 +5545,7 @@ reveal_type(f5(42)) # N: Revealed type is "__main__.A" reveal_type(f5("Hello")) # N: Revealed type is "__main__.B" [case testOverloadIfNotMatching] +# flags: --always-true True from typing import overload, Any class A: ... @@ -5555,11 +5557,11 @@ def f1(g: int) -> A: ... @overload def f1(g: bytes) -> B: ... if True: - @overload # E: Name "f1" already defined on line 7 \ + @overload # E: Name "f1" already defined on line 8 \ # E: Single overload definition, multiple required def f1(g: str) -> C: ... pass # Some other action -def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7 +def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 8 reveal_type(f1(42)) # N: Revealed type is "__main__.A" reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \ # N: Possible overload variants: \ @@ -5571,7 +5573,7 @@ if True: pass # Some other action @overload # E: Single overload definition, multiple required def f2(g: int) -> A: ... -@overload # E: Name "f2" already defined on line 21 +@overload # E: Name "f2" already defined on line 22 def f2(g: bytes) -> B: ... @overload def f2(g: str) -> C: ... From ee6ad3ce2777e90bcfce99e42b017e2359b2b4ca Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 14 Dec 2021 22:43:49 +0100 Subject: [PATCH 04/20] Fix typing issues --- mypy/fastparse.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 35b8e7a599db..fcc810007836 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -448,7 +448,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload: List[OverloadPart] = [] current_overload_name: Optional[str] = None last_if_stmt: Optional[IfStmt] = None - last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None + last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None for stmt in stmts: if_overload_name: Optional[str] = None if_block_with_overload: Optional[Block] = None @@ -496,7 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if isinstance(if_block_with_overload.body[0], OverloadedFuncDef): current_overload.extend(if_block_with_overload.body[0].items) else: - current_overload.append(if_block_with_overload.body[0]) + current_overload.append( + cast(Union[Decorator, FuncDef], if_block_with_overload.body[0]) + ) else: if last_if_stmt is not None: ret.append(last_if_stmt) @@ -523,7 +525,10 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload = [] current_overload_name = if_overload_name last_if_stmt = stmt - last_if_overload = if_block_with_overload.body[0] + last_if_overload = cast( + Union[Decorator, FuncDef, OverloadedFuncDef], + if_block_with_overload.body[0] + ) else: current_overload = [] current_overload_name = None @@ -549,7 +554,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: ): return None - overload_name = stmt.body[0].body[0].name + overload_name = cast( + Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[0] + ).name if stmt.else_body is None: return overload_name From bc486e00b35b410c40fe074862dffce6e74453af Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 01:17:25 +0100 Subject: [PATCH 05/20] Fix small issue with merging IfStmt --- mypy/fastparse.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index fcc810007836..0f716db49ec2 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -449,6 +449,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload_name: Optional[str] = None last_if_stmt: Optional[IfStmt] = None last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None + last_if_stmt_overload_name: Optional[str] = None for stmt in stmts: if_overload_name: Optional[str] = None if_block_with_overload: Optional[Block] = None @@ -502,8 +503,13 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: else: if last_if_stmt is not None: ret.append(last_if_stmt) + last_if_stmt_overload_name = current_overload_name last_if_stmt, last_if_overload = None, None + if current_overload and current_overload_name == last_if_stmt_overload_name: + # Remove last stmt (IfStmt) from ret if the overload names matched + # Only happens if no executable block had been found in IfStmt + ret.pop() if len(current_overload) == 1: ret.append(current_overload[0]) elif len(current_overload) > 1: @@ -520,15 +526,16 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: elif ( isinstance(stmt, IfStmt) and if_overload_name is not None - and if_block_with_overload is not None ): current_overload = [] current_overload_name = if_overload_name last_if_stmt = stmt - last_if_overload = cast( - Union[Decorator, FuncDef, OverloadedFuncDef], - if_block_with_overload.body[0] - ) + last_if_stmt_overload_name = None + if if_block_with_overload is not None: + last_if_overload = cast( + Union[Decorator, FuncDef, OverloadedFuncDef], + if_block_with_overload.body[0] + ) else: current_overload = [] current_overload_name = None From a1370e08077b5c7e8371f8effe25785bf99a94b8 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 01:19:23 +0100 Subject: [PATCH 06/20] Update existing tests --- test-data/unit/check-overloading.test | 203 ++++++++++++++------------ 1 file changed, 113 insertions(+), 90 deletions(-) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 365f6c9a2a3a..91ff9e5080ce 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5425,190 +5425,213 @@ def f_f(arg): ... [case testOverloadIfBasic] # flags: --always-true True -from typing import overload, Any +from typing import overload class A: ... class B: ... +class C: ... + +# ----- +# Test basic overload merging +# ----- @overload -def f1(g: int) -> A: ... +def f1(g: A) -> A: ... if True: @overload - def f1(g: str) -> B: ... -def f1(g: Any) -> Any: ... -reveal_type(f1(42)) # N: Revealed type is "__main__.A" -reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + def f1(g: B) -> B: ... +def f1(g): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" @overload -def f2(g: int) -> A: ... +def f2(g: A) -> A: ... @overload -def f2(g: bytes) -> A: ... +def f2(g: B) -> B: ... if not True: @overload - def f2(g: str) -> B: ... -def f2(g: Any) -> Any: ... -reveal_type(f2(42)) # N: Revealed type is "__main__.A" -reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ - # N: Possible overload variants: \ - # N: def f2(g: int) -> A \ - # N: def f2(g: bytes) -> A \ - # N: Revealed type is "Any" + def f2(g: C) -> C: ... +def f2(g): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f2(g: A) -> A \ + # N: def f2(g: B) -> B \ + # N: Revealed type is "Any" [case testOverloadIfSysVersion] # flags: --python-version 3.9 -from typing import overload, Any +from typing import overload import sys class A: ... class B: ... +class C: ... + +# ----- +# "Real" world example +# Test overload merging for sys.version_info +# ----- @overload -def f1(g: int) -> A: ... +def f1(g: A) -> A: ... if sys.version_info >= (3, 9): @overload - def f1(g: str) -> B: ... -def f1(g: Any) -> Any: ... -reveal_type(f1(42)) # N: Revealed type is "__main__.A" -reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + def f1(g: B) -> B: ... +def f1(g): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" @overload -def f2(g: int) -> A: ... +def f2(g: A) -> A: ... @overload -def f2(g: bytes) -> A: ... +def f2(g: B) -> B: ... if sys.version_info >= (3, 10): @overload - def f2(g: str) -> B: ... -def f2(g: Any) -> Any: ... -reveal_type(f2(42)) # N: Revealed type is "__main__.A" -reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ - # N: Possible overload variants: \ - # N: def f2(g: int) -> A \ - # N: def f2(g: bytes) -> A \ - # N: Revealed type is "Any" + def f2(g: C) -> C: ... +def f2(g): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f2(g: A) -> A \ + # N: def f2(g: B) -> B \ + # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] -[case testOverloadIfMatching] +[case testOverloadIfMerging] # flags: --always-true True -from typing import overload, Any +from typing import overload class A: ... class B: ... class C: ... +# ----- +# Test overload merging +# ----- + @overload -def f1(g: int) -> A: ... +def f1(g: A) -> A: ... if True: # Some comment @overload - def f1(g: str) -> B: ... -def f1(g: Any) -> Any: ... -reveal_type(f1(42)) # N: Revealed type is "__main__.A" -reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + def f1(g: B) -> B: ... +def f1(g): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" @overload -def f2(g: int) -> A: ... +def f2(g: A) -> A: ... if True: @overload def f2(g: bytes) -> B: ... @overload - def f2(g: str) -> C: ... -def f2(g: Any) -> Any: ... -reveal_type(f2(42)) # N: Revealed type is "__main__.A" -reveal_type(f2("Hello")) # N: Revealed type is "__main__.C" + def f2(g: B) -> C: ... +def f2(g): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(B())) # N: Revealed type is "__main__.C" @overload -def f3(g: int) -> A: ... +def f3(g: A) -> A: ... @overload -def f3(g: str) -> B: ... +def f3(g: B) -> B: ... if True: - def f3(g: Any) -> Any: ... -reveal_type(f3(42)) # N: Revealed type is "__main__.A" -reveal_type(f3("Hello")) # N: Revealed type is "__main__.B" + def f3(g): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(B())) # N: Revealed type is "__main__.B" if True: @overload - def f4(g: int) -> A: ... + def f4(g: A) -> A: ... @overload -def f4(g: str) -> B: ... -def f4(g: Any) -> Any: ... -reveal_type(f4(42)) # N: Revealed type is "__main__.A" -reveal_type(f4("Hello")) # N: Revealed type is "__main__.B" +def f4(g: B) -> B: ... +def f4(g): ... +reveal_type(f4(A())) # N: Revealed type is "__main__.A" +reveal_type(f4(B())) # N: Revealed type is "__main__.B" if True: # Some comment @overload - def f5(g: int) -> A: ... + def f5(g: A) -> A: ... @overload - def f5(g: str) -> B: ... -def f5(g: Any) -> Any: ... -reveal_type(f5(42)) # N: Revealed type is "__main__.A" -reveal_type(f5("Hello")) # N: Revealed type is "__main__.B" + def f5(g: B) -> B: ... +def f5(g): ... +reveal_type(f5(A())) # N: Revealed type is "__main__.A" +reveal_type(f5(B())) # N: Revealed type is "__main__.B" -[case testOverloadIfNotMatching] +[case testOverloadIfNotMerging] # flags: --always-true True -from typing import overload, Any +from typing import overload class A: ... class B: ... class C: ... +# ----- +# Don't merge if IfStmt contains nodes other than overloads +# ----- + @overload # E: An overloaded function outside a stub file must have an implementation -def f1(g: int) -> A: ... +def f1(g: A) -> A: ... @overload -def f1(g: bytes) -> B: ... +def f1(g: B) -> B: ... if True: - @overload # E: Name "f1" already defined on line 8 \ + @overload # E: Name "f1" already defined on line 12 \ # E: Single overload definition, multiple required - def f1(g: str) -> C: ... + def f1(g: C) -> C: ... pass # Some other action -def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 8 -reveal_type(f1(42)) # N: Revealed type is "__main__.A" -reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \ +def f1(g): ... # E: Name "f1" already defined on line 12 +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(C())) # E: No overload variant of "f1" matches argument type "C" \ # N: Possible overload variants: \ - # N: def f1(g: int) -> A \ - # N: def f1(g: bytes) -> B \ + # N: def f1(g: A) -> A \ + # N: def f1(g: B) -> B \ # N: Revealed type is "Any" if True: pass # Some other action @overload # E: Single overload definition, multiple required - def f2(g: int) -> A: ... -@overload # E: Name "f2" already defined on line 22 -def f2(g: bytes) -> B: ... + def f2(g: A) -> A: ... +@overload # E: Name "f2" already defined on line 26 +def f2(g: B) -> B: ... @overload -def f2(g: str) -> C: ... -def f2(g: Any) -> Any: ... -reveal_type(f2(42)) # N: Revealed type is "__main__.A" -reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \ - # E: Argument 1 to "f2" has incompatible type "str"; expected "int" +def f2(g: C) -> C: ... +def f2(g): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(C())) # N: Revealed type is "__main__.A" \ + # E: Argument 1 to "f2" has incompatible type "C"; expected "A" [case testOverloadIfOldStyle] # flags: --always-false var_false --always-true var_true -from typing import overload, Any +from typing import overload class A: ... class B: ... +# ----- +# Test old style to make sure it still works +# ----- + var_true = True var_false = False if var_false: @overload - def f1(g: int) -> A: ... + def f1(g: A) -> A: ... @overload - def f1(g: str) -> B: ... - def f1(g: Any) -> Any: ... + def f1(g: B) -> B: ... + def f1(g): ... elif var_true: @overload - def f1(g: int) -> A: ... + def f1(g: A) -> A: ... @overload - def f1(g: str) -> B: ... - def f1(g: Any) -> Any: ... + def f1(g: B) -> B: ... + def f1(g): ... else: @overload - def f1(g: int) -> A: ... + def f1(g: A) -> A: ... @overload - def f1(g: str) -> B: ... - def f1(g: Any) -> Any: ... -reveal_type(f1(42)) # N: Revealed type is "__main__.A" -reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + def f1(g: B) -> B: ... + def f1(g): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" From 0d2dee36e0ef13fd09c4b2935d2b96a9f4244be9 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 01:20:16 +0100 Subject: [PATCH 07/20] Add additional tests --- test-data/unit/check-overloading.test | 466 ++++++++++++++++++++++++++ 1 file changed, 466 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 91ff9e5080ce..39f0460a769e 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5635,3 +5635,469 @@ else: def f1(g): ... reveal_type(f1(A())) # N: Revealed type is "__main__.A" reveal_type(f1(B())) # N: Revealed type is "__main__.B" + +[case testOverloadIfElse] +# flags: --always-true True --always-false False +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Match the first always-true block +# ----- + +@overload +def f1(x: A) -> A: ... +if True: + @overload + def f1(x: B) -> B: ... +elif False: + @overload + def f1(x: C) -> C: ... +else: + @overload + def f1(x: D) -> D: ... +def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" +reveal_type(f1(C())) # E: No overload variant of "f1" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f1(x: A) -> A \ + # N: def f1(x: B) -> B \ + # N: Revealed type is "Any" + +@overload +def f2(x: A) -> A: ... +if False: + @overload + def f2(x: B) -> B: ... +elif True: + @overload + def f2(x: C) -> C: ... +else: + @overload + def f2(x: D) -> D: ... +def f2(x): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(B())) # E: No overload variant of "f2" matches argument type "B" \ + # N: Possible overload variants: \ + # N: def f2(x: A) -> A \ + # N: def f2(x: C) -> C \ + # N: Revealed type is "Any" +reveal_type(f2(C())) # N: Revealed type is "__main__.C" + +@overload +def f3(x: A) -> A: ... +if False: + @overload + def f3(x: B) -> B: ... +elif False: + @overload + def f3(x: C) -> C: ... +else: + @overload + def f3(x: D) -> D: ... +def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(C())) # E: No overload variant of "f3" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f3(x: A) -> A \ + # N: def f3(x: D) -> D \ + # N: Revealed type is "Any" +reveal_type(f3(D())) # N: Revealed type is "__main__.D" + +[case testOverloadIfElse2] +# flags: --always-true True +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Match the first always-true block +# Don't merge overloads if can't be certain about execution of block +# ----- + +@overload +def f1(x: A) -> A: ... +if True: + @overload + def f1(x: B) -> B: ... +else: + @overload + def f1(x: D) -> D: ... +def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" +reveal_type(f1(D())) # E: No overload variant of "f1" matches argument type "D" \ + # N: Possible overload variants: \ + # N: def f1(x: A) -> A \ + # N: def f1(x: B) -> B \ + # N: Revealed type is "Any" + +@overload +def f2(x: A) -> A: ... +if True: + @overload + def f2(x: B) -> B: ... +elif maybe_true: + @overload + def f2(x: C) -> C: ... +else: + @overload + def f2(x: D) -> D: ... +def f2(x): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(B())) # N: Revealed type is "__main__.B" +reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f2(x: A) -> A \ + # N: def f2(x: B) -> B \ + # N: Revealed type is "Any" + +@overload # E: Single overload definition, multiple required +def f3(x: A) -> A: ... +if maybe_true: + @overload + def f3(x: B) -> B: ... +elif True: + @overload + def f3(x: C) -> C: ... +else: + @overload + def f3(x: D) -> D: ... +def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(B())) # E: No overload variant of "f3" matches argument type "B" \ + # N: Possible overload variant: \ + # N: def f3(x: A) -> A \ + # N: Revealed type is "Any" + +@overload # E: Single overload definition, multiple required +def f4(x: A) -> A: ... +if maybe_true: + @overload + def f4(x: B) -> B: ... +else: + @overload + def f4(x: D) -> D: ... +def f4(x): ... +reveal_type(f4(A())) # N: Revealed type is "__main__.A" +reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \ + # N: Possible overload variant: \ + # N: def f4(x: A) -> A \ + # N: Revealed type is "Any" + +[case testOverloadIfElse3] +# flags: --always-false False +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Match the first always-true block +# Don't merge overloads if can't be certain about execution of block +# ----- + +@overload +def f1(x: A) -> A: ... +if False: + @overload + def f1(x: B) -> B: ... +else: + @overload + def f1(x: D) -> D: ... +def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # E: No overload variant of "f1" matches argument type "B" \ + # N: Possible overload variants: \ + # N: def f1(x: A) -> A \ + # N: def f1(x: D) -> D \ + # N: Revealed type is "Any" +reveal_type(f1(D())) # N: Revealed type is "__main__.D" + +@overload # E: Single overload definition, multiple required +def f2(x: A) -> A: ... +if False: + @overload + def f2(x: B) -> B: ... +elif maybe_true: + @overload + def f2(x: C) -> C: ... +else: + @overload + def f2(x: D) -> D: ... +def f2(x): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" \ + # N: Possible overload variant: \ + # N: def f2(x: A) -> A \ + # N: Revealed type is "Any" + +@overload # E: Single overload definition, multiple required +def f3(x: A) -> A: ... +if maybe_true: + @overload + def f3(x: B) -> B: ... +elif False: + @overload + def f3(x: C) -> C: ... +else: + @overload + def f3(x: D) -> D: ... +def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(B())) # E: No overload variant of "f3" matches argument type "B" \ + # N: Possible overload variant: \ + # N: def f3(x: A) -> A \ + # N: Revealed type is "Any" + +[case testOverloadIfSkipUnknownExecution] +# flags: --always-true True +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# If blocks should be skipped if execution can't be certain +# Overload name must match outer name +# ----- + +@overload # E: Single overload definition, multiple required +def f1(x: A) -> A: ... +if maybe_true: + @overload + def f1(x: B) -> B: ... +def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" + +if maybe_true: + @overload + def f2(x: A) -> A: ... +@overload +def f2(x: B) -> B: ... +@overload +def f2(x: C) -> C: ... +def f2(x): ... +reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f2(x: B) -> B \ + # N: def f2(x: C) -> C \ + # N: Revealed type is "Any" + +if True: + @overload # E: Single overload definition, multiple required + def f3(x: A) -> A: ... + if maybe_true: + @overload + def f3(x: B) -> B: ... + def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" + +if True: + if maybe_true: + @overload + def f4(x: A) -> A: ... + @overload + def f4(x: B) -> B: ... + @overload + def f4(x: C) -> C: ... + def f4(x): ... +reveal_type(f4(A())) # E: No overload variant of "f4" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f4(x: B) -> B \ + # N: def f4(x: C) -> C \ + # N: Revealed type is "Any" + +[case testOverloadIfDontSkipUnrelatedOverload] +# flags: --always-true True +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Don't skip if block if overload name doesn't match outer name +# ----- + +@overload # E: Single overload definition, multiple required +def f1(x: A) -> A: ... +if maybe_true: # E: Name "maybe_true" is not defined + @overload # E: Single overload definition, multiple required + def g1(x: B) -> B: ... +def f1(x): ... # E: Name "f1" already defined on line 13 +reveal_type(f1(A())) # N: Revealed type is "__main__.A" + +if maybe_true: # E: Name "maybe_true" is not defined + @overload # E: Single overload definition, multiple required + def g2(x: A) -> A: ... +@overload +def f2(x: B) -> B: ... +@overload +def f2(x: C) -> C: ... +def f2(x): ... +reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f2(x: B) -> B \ + # N: def f2(x: C) -> C \ + # N: Revealed type is "Any" + +if True: + @overload # E: Single overload definition, multiple required + def f3(x: A) -> A: ... + if maybe_true: # E: Name "maybe_true" is not defined + @overload # E: Single overload definition, multiple required + def g3(x: B) -> B: ... + def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" + +if True: + if maybe_true: # E: Name "maybe_true" is not defined + @overload # E: Single overload definition, multiple required + def g4(x: A) -> A: ... + @overload + def f4(x: B) -> B: ... + @overload + def f4(x: C) -> C: ... + def f4(x): ... +reveal_type(f4(A())) # E: No overload variant of "f4" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f4(x: B) -> B \ + # N: def f4(x: C) -> C \ + # N: Revealed type is "Any" + +[case testOverloadIfNotMergingDifferentNames] +# flags: --always-true True +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Don't merge overloads if IfStmts contains overload with different name +# ----- + +@overload # E: An overloaded function outside a stub file must have an implementation +def f1(x: A) -> A: ... +@overload +def f1(x: B) -> B: ... +if True: + @overload # E: Single overload definition, multiple required + def g1(x: C) -> C: ... +def f1(x): ... # E: Name "f1" already defined on line 13 +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(C())) # E: No overload variant of "f1" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f1(x: A) -> A \ + # N: def f1(x: B) -> B \ + # N: Revealed type is "Any" + +if True: + @overload # E: Single overload definition, multiple required + def g1(x: A) -> A: ... +@overload +def f2(x: B) -> B: ... +@overload +def f2(x: C) -> C: ... +def f2(x): ... +reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f2(x: B) -> B \ + # N: def f2(x: C) -> C \ + # N: Revealed type is "Any" +reveal_type(f2(B())) # N: Revealed type is "__main__.B" + +[case testOverloadIfSplitFunctionDef] +# flags: --always-true True --always-false False +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +# ----- +# Test split FuncDefs +# ----- + +@overload +def f1(x: A) -> A: ... +@overload +def f1(x: B) -> B: ... +if True: + def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" + +@overload +def f2(x: A) -> A: ... +@overload +def f2(x: B) -> B: ... +if False: + def f2(x): ... +else: + def f2(x): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" + +@overload # E: An overloaded function outside a stub file must have an implementation +def f3(x: A) -> A: ... +@overload +def f3(x: B) -> B: ... +if True: + def f3(x): ... # E: Name "f3" already defined on line 31 +else: + pass # some other node + def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" + +[case testOverloadIfMixed] +# flags: --always-true True +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +if maybe_var: # E: Name "maybe_var" is not defined + pass +if True: + @overload + def f1(x: A) -> A: ... +@overload +def f1(x: B) -> B: ... +def f1(x): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" + +if True: + @overload + def f2(x: A) -> A: ... + @overload + def f2(x: B) -> B: ... +def f2(x): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(B())) # N: Revealed type is "__main__.B" + +if True: + @overload + def f3(x: A) -> A: ... + @overload + def f3(x: B) -> B: ... + def f3(x): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(B())) # N: Revealed type is "__main__.B" From ae394cc65673972e9cce2ef66262105f57168100 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 01:42:36 +0100 Subject: [PATCH 08/20] Fix check-functions tests --- test-data/unit/check-functions.test | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 62107686880f..efbeebd10fee 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -1400,8 +1400,9 @@ def top() -> None: from typing import Any x = None # type: Any if x: + pass # some other node def f(): pass -def f(): pass # E: Name "f" already defined on line 4 +def f(): pass # E: Name "f" already defined on line 5 [case testIncompatibleConditionalFunctionDefinition] from typing import Any @@ -1646,8 +1647,9 @@ from typing import Any x = None # type: Any class A: if x: + pass # Some other node def f(self): pass - def f(self): pass # E: Name "f" already defined on line 5 + def f(self): pass # E: Name "f" already defined on line 6 [case testIncompatibleConditionalMethodDefinition] from typing import Any @@ -2212,7 +2214,7 @@ from typing import Callable class A: def f(self) -> None: # In particular, test that the error message contains "g" of "A". - self.g() # E: Too few arguments for "g" of "A" + self.g() # E: Too few arguments for "g" of "A" self.g(1) @dec def g(self, x: str) -> None: pass From 1dd5679e8349454e42b4c86a0fa04fed656900e4 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 02:12:15 +0100 Subject: [PATCH 09/20] Fix crash --- mypy/fastparse.py | 2 +- test-data/unit/check-overloading.test | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 0f716db49ec2..58b3dccf24e6 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -557,7 +557,7 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: # Multiple overloads have already been merged as OverloadedFuncDef. if not ( len(stmt.body[0].body) == 1 - or isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) ): return None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 39f0460a769e..f16cb3bfad6a 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6065,7 +6065,7 @@ else: reveal_type(f3(A())) # N: Revealed type is "__main__.A" [case testOverloadIfMixed] -# flags: --always-true True +# flags: --always-true True --always-false False from typing import overload class A: ... @@ -6101,3 +6101,13 @@ if True: def f3(x): ... reveal_type(f3(A())) # N: Revealed type is "__main__.A" reveal_type(f3(B())) # N: Revealed type is "__main__.B" + +# Don't crash with AssignmentStmt if elif +@overload # E: Single overload definition, multiple required +def f4(x: A) -> A: ... +if False: + @overload + def f4(x: B) -> B: ... +elif True: + var = 1 +def f4(x): ... # E: Name "f4" already defined on line 39 From a8c3899859ae7da526898f2c23ab1ae200f73432 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 02:14:42 +0100 Subject: [PATCH 10/20] Remove redundant cast --- mypy/fastparse.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 58b3dccf24e6..864148e14bff 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -561,9 +561,7 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: ): return None - overload_name = cast( - Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[0] - ).name + overload_name = stmt.body[0].body[0].name if stmt.else_body is None: return overload_name From 3a93f7ea4ec25eb99533ed6f3bfc481974b2a974 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 03:02:37 +0100 Subject: [PATCH 11/20] Add last test cases --- test-data/unit/check-overloading.test | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index f16cb3bfad6a..da0c0489a63e 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6111,3 +6111,39 @@ if False: elif True: var = 1 def f4(x): ... # E: Name "f4" already defined on line 39 + +if TYPE_CHECKING: + @overload + def f5(x: A) -> A: ... + @overload + def f5(x: B) -> B: ... +def f5(x): ... +reveal_type(f5(A())) # N: Revealed type is "__main__.A" +reveal_type(f5(B())) # N: Revealed type is "__main__.B" + +# Test from check-functions - testUnconditionalRedefinitionOfConditionalFunction +# If IfStmt only contains FuncDef, block is ignore if uncertain about execution +# Necessary to be able to ignore always-false cases +if maybe_true: + def f6(x): ... +def f6(x): ... + +if maybe_true: # E: Name "maybe_true" is not defined + pass # Some other node + def f7(x): ... +def f7(x): ... # E: Name "f7" already defined on line 66 + +@overload +def f8(x: A) -> A: ... +@overload +def f8(x: B) -> B: ... +if False: + def f8(x: C) -> C: ... +def f8(x): ... +reveal_type(f8(A())) # N: Revealed type is "__main__.A" +reveal_type(f8(C())) # E: No overload variant of "f8" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f8(x: A) -> A \ + # N: def f8(x: B) -> B \ + # N: Revealed type is "Any" +[typing fixtures/typing-medium.pyi] From 9a4c703806caeda1c722e7bc586d3e083a8242ad Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 14:53:56 +0100 Subject: [PATCH 12/20] Fix tests --- test-data/unit/check-overloading.test | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index da0c0489a63e..b715a716fd60 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5424,7 +5424,7 @@ def f_f(arg): ... [case testOverloadIfBasic] -# flags: --always-true True +# flags: --always-true True --always-false False from typing import overload class A: ... @@ -5448,7 +5448,7 @@ reveal_type(f1(B())) # N: Revealed type is "__main__.B" def f2(g: A) -> A: ... @overload def f2(g: B) -> B: ... -if not True: +if False: @overload def f2(g: C) -> C: ... def f2(g): ... @@ -5459,6 +5459,21 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" # N: def f2(g: B) -> B \ # N: Revealed type is "Any" +@overload +def f3(g: A) -> A: ... +@overload +def f3(g: B) -> B: ... +if maybe_true: + @overload + def f3(g: C) -> C: ... +def f3(g): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(C())) # E: No overload variant of "f3" matches argument type "C" \ + # N: Possible overload variants: \ + # N: def f3(g: A) -> A \ + # N: def f3(g: B) -> B \ + # N: Revealed type is "Any" + [case testOverloadIfSysVersion] # flags: --python-version 3.9 from typing import overload @@ -5496,7 +5511,7 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" # N: def f2(g: A) -> A \ # N: def f2(g: B) -> B \ # N: Revealed type is "Any" -[builtins fixtures/tuple.pyi] +[builtins fixtures/ops.pyi] [case testOverloadIfMerging] # flags: --always-true True @@ -6009,7 +6024,7 @@ reveal_type(f1(C())) # E: No overload variant of "f1" matches argument type "C" if True: @overload # E: Single overload definition, multiple required - def g1(x: A) -> A: ... + def g2(x: A) -> A: ... @overload def f2(x: B) -> B: ... @overload @@ -6066,7 +6081,7 @@ reveal_type(f3(A())) # N: Revealed type is "__main__.A" [case testOverloadIfMixed] # flags: --always-true True --always-false False -from typing import overload +from typing import overload, TYPE_CHECKING class A: ... class B: ... From 2777ce196b58e22658ef0f8182eadf291d472445 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 14:56:35 +0100 Subject: [PATCH 13/20] Typecheck skipped IfStmt conditions --- mypy/fastparse.py | 34 +++++++++++++++- test-data/unit/check-overloading.test | 57 ++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 864148e14bff..648d89c8da26 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -450,6 +450,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: last_if_stmt: Optional[IfStmt] = None last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None last_if_stmt_overload_name: Optional[str] = None + skipped_if_stmts: List[IfStmt] = [] for stmt in stmts: if_overload_name: Optional[str] = None if_block_with_overload: Optional[Block] = None @@ -466,6 +467,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if (current_overload_name is not None and isinstance(stmt, (Decorator, FuncDef)) and stmt.name == current_overload_name): + if last_if_stmt is not None: + skipped_if_stmts.append(last_if_stmt) if last_if_overload is not None: # Last stmt was an IfStmt with same overload name # Add overloads to current_overload @@ -484,6 +487,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # Check if stmts are reachable and add them to current_overload, # otherwise skip IfStmt to allow subsequent overload # or function definitions. + skipped_if_stmts.append(stmt) if if_block_with_overload is None: continue if last_if_overload is not None: @@ -509,7 +513,14 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if current_overload and current_overload_name == last_if_stmt_overload_name: # Remove last stmt (IfStmt) from ret if the overload names matched # Only happens if no executable block had been found in IfStmt - ret.pop() + skipped_if_stmts.append(cast(IfStmt, ret.pop())) + if current_overload and skipped_if_stmts: + # Add bare IfStmt (without overloads) to ret + # Required for mypy to be able to still check conditions + for if_stmt in skipped_if_stmts: + ASTConverter._strip_contents_from_if_stmt(if_stmt) + ret.append(if_stmt) + skipped_if_stmts = [] if len(current_overload) == 1: ret.append(current_overload[0]) elif len(current_overload) > 1: @@ -541,6 +552,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload_name = None ret.append(stmt) + if current_overload and skipped_if_stmts: + # Add bare IfStmt (without overloads) to ret + # Required for mypy to be able to still check conditions + for if_stmt in skipped_if_stmts: + ASTConverter._strip_contents_from_if_stmt(if_stmt) + ret.append(if_stmt) if len(current_overload) == 1: ret.append(current_overload[0]) elif len(current_overload) > 1: @@ -604,6 +621,21 @@ def _get_executable_if_block_with_overloads(self, stmt: IfStmt) -> Optional[Bloc return stmt.else_body return None + @staticmethod + def _strip_contents_from_if_stmt(stmt: IfStmt) -> None: + """Remove contents from IfStmt. + + Needed to still be able to check the conditions after the contents + have been merged with the surrunding function overloads. + """ + if len(stmt.body) == 1: + stmt.body[0].body = [] + if stmt.else_body and len(stmt.else_body.body) == 1: + if isinstance(stmt.else_body.body[0], IfStmt): + ASTConverter._strip_contents_from_if_stmt(stmt.else_body.body[0]) + else: + stmt.else_body.body = [] + def in_method_scope(self) -> bool: return self.class_and_function_stack[-2:] == ['C', 'F'] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index b715a716fd60..4fc3e6503fa3 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5463,7 +5463,7 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" def f3(g: A) -> A: ... @overload def f3(g: B) -> B: ... -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f3(g: C) -> C: ... def f3(g): ... @@ -5777,7 +5777,7 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... elif True: @@ -5795,7 +5795,7 @@ reveal_type(f3(B())) # E: No overload variant of "f3" matches argument type "B" @overload # E: Single overload definition, multiple required def f4(x: A) -> A: ... -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f4(x: B) -> B: ... else: @@ -5844,7 +5844,7 @@ def f2(x: A) -> A: ... if False: @overload def f2(x: B) -> B: ... -elif maybe_true: +elif maybe_true: # E: Name "maybe_true" is not defined @overload def f2(x: C) -> C: ... else: @@ -5859,7 +5859,7 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... elif False: @@ -5891,13 +5891,13 @@ class D: ... @overload # E: Single overload definition, multiple required def f1(x: A) -> A: ... -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f1(x: B) -> B: ... def f1(x): ... reveal_type(f1(A())) # N: Revealed type is "__main__.A" -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined @overload def f2(x: A) -> A: ... @overload @@ -5914,14 +5914,14 @@ reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" if True: @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... - if maybe_true: + if maybe_true: # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... def f3(x): ... reveal_type(f3(A())) # N: Revealed type is "__main__.A" if True: - if maybe_true: + if maybe_true: # E: Name "maybe_true" is not defined @overload def f4(x: A) -> A: ... @overload @@ -6139,7 +6139,7 @@ reveal_type(f5(B())) # N: Revealed type is "__main__.B" # Test from check-functions - testUnconditionalRedefinitionOfConditionalFunction # If IfStmt only contains FuncDef, block is ignore if uncertain about execution # Necessary to be able to ignore always-false cases -if maybe_true: +if maybe_true: # E: Name "maybe_true" is not defined def f6(x): ... def f6(x): ... @@ -6161,4 +6161,41 @@ reveal_type(f8(C())) # E: No overload variant of "f8" matches argument type "C" # N: def f8(x: A) -> A \ # N: def f8(x: B) -> B \ # N: Revealed type is "Any" + +if maybe_true: # E: Name "maybe_true" is not defined + @overload + def f9(x: A) -> A: ... +if another_maybe_true: # E: Name "another_maybe_true" is not defined + @overload + def f9(x: B) -> B: ... +@overload +def f9(x: C) -> C: ... +@overload +def f9(x: D) -> D: ... +def f9(x): ... +reveal_type(f9(A())) # E: No overload variant of "f9" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f9(x: C) -> C \ + # N: def f9(x: D) -> D \ + # N: Revealed type is "Any" +reveal_type(f9(C())) # N: Revealed type is "__main__.C" + +if True: + if maybe_true: # E: Name "maybe_true" is not defined + @overload + def f10(x: A) -> A: ... + if another_maybe_true: # E: Name "another_maybe_true" is not defined + @overload + def f10(x: B) -> B: ... + @overload + def f10(x: C) -> C: ... + @overload + def f10(x: D) -> D: ... + def f10(x): ... +reveal_type(f10(A())) # E: No overload variant of "f10" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f10(x: C) -> C \ + # N: def f10(x: D) -> D \ + # N: Revealed type is "Any" +reveal_type(f10(C())) # N: Revealed type is "__main__.C" [typing fixtures/typing-medium.pyi] From b36a5f84704a4e9064de810ed1ec871315f18512 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 15 Dec 2021 15:52:07 +0100 Subject: [PATCH 14/20] More tests --- test-data/unit/check-overloading.test | 35 +++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4fc3e6503fa3..7a6c78c2da0d 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6037,6 +6037,22 @@ reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" # N: Revealed type is "Any" reveal_type(f2(B())) # N: Revealed type is "__main__.B" +if True: + if True: + @overload # E: Single overload definition, multiple required + def g3(x: A) -> A: ... + @overload + def f3(x: B) -> B: ... + @overload + def f3(x: C) -> C: ... + def f3(x): ... +reveal_type(f3(A())) # E: No overload variant of "f3" matches argument type "A" \ + # N: Possible overload variants: \ + # N: def f3(x: B) -> B \ + # N: def f3(x: C) -> C \ + # N: Revealed type is "Any" +reveal_type(f3(B())) # N: Revealed type is "__main__.B" + [case testOverloadIfSplitFunctionDef] # flags: --always-true True --always-false False from typing import overload @@ -6198,4 +6214,23 @@ reveal_type(f10(A())) # E: No overload variant of "f10" matches argument type " # N: def f10(x: D) -> D \ # N: Revealed type is "Any" reveal_type(f10(C())) # N: Revealed type is "__main__.C" + +if some_var: # E: Name "some_var" is not defined + pass +@overload +def f11(x: A) -> A: ... +@overload +def f11(x: B) -> B: ... +def f11(x): ... +reveal_type(f11(A())) # N: Revealed type is "__main__.A" + +if True: + if some_var: # E: Name "some_var" is not defined + pass + @overload + def f12(x: A) -> A: ... + @overload + def f12(x: B) -> B: ... + def f12(x): ... +reveal_type(f12(A())) # N: Revealed type is "__main__.A" [typing fixtures/typing-medium.pyi] From 80b05d5a3432f7be9dec27d0e51610ff68a44171 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:57:01 +0100 Subject: [PATCH 15/20] Apply suggestions from review --- mypy/fastparse.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 699497e33e5b..f916785d5c11 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -547,7 +547,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # Add bare IfStmt (without overloads) to ret # Required for mypy to be able to still check conditions for if_stmt in skipped_if_stmts: - ASTConverter._strip_contents_from_if_stmt(if_stmt) + self._strip_contents_from_if_stmt(if_stmt) ret.append(if_stmt) skipped_if_stmts = [] if len(current_overload) == 1: @@ -585,7 +585,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # Add bare IfStmt (without overloads) to ret # Required for mypy to be able to still check conditions for if_stmt in skipped_if_stmts: - ASTConverter._strip_contents_from_if_stmt(if_stmt) + self._strip_contents_from_if_stmt(if_stmt) ret.append(if_stmt) if len(current_overload) == 1: ret.append(current_overload[0]) @@ -650,18 +650,17 @@ def _get_executable_if_block_with_overloads(self, stmt: IfStmt) -> Optional[Bloc return stmt.else_body return None - @staticmethod - def _strip_contents_from_if_stmt(stmt: IfStmt) -> None: + def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None: """Remove contents from IfStmt. Needed to still be able to check the conditions after the contents - have been merged with the surrunding function overloads. + have been merged with the surrounding function overloads. """ if len(stmt.body) == 1: stmt.body[0].body = [] if stmt.else_body and len(stmt.else_body.body) == 1: if isinstance(stmt.else_body.body[0], IfStmt): - ASTConverter._strip_contents_from_if_stmt(stmt.else_body.body[0]) + self._strip_contents_from_if_stmt(stmt.else_body.body[0]) else: stmt.else_body.body = [] From 3d0397a62aba7f36154391f18075059bd63b6cba Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 2 Mar 2022 19:17:20 +0100 Subject: [PATCH 16/20] Don't merge starting If blocks without overloads --- mypy/fastparse.py | 6 +++++- test-data/unit/check-functions.test | 6 ++---- test-data/unit/check-overloading.test | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index f916785d5c11..17bedf79053a 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -486,7 +486,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if ( isinstance(stmt, IfStmt) and len(stmt.body[0].body) == 1 - and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and ( + isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) + or current_overload_name is not None + and isinstance(stmt.body[0].body[0], FuncDef) + ) ): # Check IfStmt block to determine if function overloads can be merged if_overload_name = self._check_ifstmt_for_overloads(stmt) diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 103af1948f58..30150ca436e8 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -1400,9 +1400,8 @@ def top() -> None: from typing import Any x = None # type: Any if x: - pass # some other node def f(): pass -def f(): pass # E: Name "f" already defined on line 5 +def f(): pass # E: Name "f" already defined on line 4 [case testIncompatibleConditionalFunctionDefinition] from typing import Any @@ -1647,9 +1646,8 @@ from typing import Any x = None # type: Any class A: if x: - pass # Some other node def f(self): pass - def f(self): pass # E: Name "f" already defined on line 6 + def f(self): pass # E: Name "f" already defined on line 5 [case testIncompatibleConditionalMethodDefinition] from typing import Any diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 9b25b50ccbba..511c51e84342 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6153,11 +6153,11 @@ reveal_type(f5(A())) # N: Revealed type is "__main__.A" reveal_type(f5(B())) # N: Revealed type is "__main__.B" # Test from check-functions - testUnconditionalRedefinitionOfConditionalFunction -# If IfStmt only contains FuncDef, block is ignore if uncertain about execution -# Necessary to be able to ignore always-false cases +# Don't merge If blocks if they appear before any overloads +# and don't contain any overloads themselves. if maybe_true: # E: Name "maybe_true" is not defined def f6(x): ... -def f6(x): ... +def f6(x): ... # E: Name "f6" already defined on line 61 if maybe_true: # E: Name "maybe_true" is not defined pass # Some other node From 914d517be663daf4af166731dcabae85a6e736e5 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 2 Mar 2022 20:13:07 +0100 Subject: [PATCH 17/20] Emit error if condition can't be inferred --- mypy/fastparse.py | 49 ++++++++++++++++---- test-data/unit/check-overloading.test | 64 +++++++++++++++++++++------ 2 files changed, 91 insertions(+), 22 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 17bedf79053a..483e5eb4bc42 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -344,9 +344,19 @@ def fail(self, msg: str, line: int, column: int, - blocker: bool = True) -> None: + blocker: bool = True, + code: codes.ErrorCode = codes.SYNTAX) -> None: if blocker or not self.options.ignore_errors: - self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX) + self.errors.report(line, column, msg, blocker=blocker, code=code) + + def fail_merge_overload(self, node: IfStmt) -> None: + self.fail( + "Condition can't be inferred, unable to merge overloads", + line=node.line, + column=node.column, + blocker=False, + code=codes.MISC, + ) def visit(self, node: Optional[AST]) -> Any: if node is None: @@ -479,10 +489,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: last_if_stmt: Optional[IfStmt] = None last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None last_if_stmt_overload_name: Optional[str] = None + last_if_unknown_truth_value: Optional[IfStmt] = None skipped_if_stmts: List[IfStmt] = [] for stmt in stmts: if_overload_name: Optional[str] = None if_block_with_overload: Optional[Block] = None + if_unknown_truth_value: Optional[IfStmt] = None if ( isinstance(stmt, IfStmt) and len(stmt.body[0].body) == 1 @@ -495,7 +507,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # Check IfStmt block to determine if function overloads can be merged if_overload_name = self._check_ifstmt_for_overloads(stmt) if if_overload_name is not None: - if_block_with_overload = self._get_executable_if_block_with_overloads(stmt) + if_block_with_overload, if_unknown_truth_value = \ + self._get_executable_if_block_with_overloads(stmt) if (current_overload_name is not None and isinstance(stmt, (Decorator, FuncDef)) @@ -510,6 +523,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: else: current_overload.append(last_if_overload) last_if_stmt, last_if_overload = None, None + if last_if_unknown_truth_value: + self.fail_merge_overload(last_if_unknown_truth_value) + last_if_unknown_truth_value = None current_overload.append(stmt) elif ( current_overload_name is not None @@ -522,6 +538,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # or function definitions. skipped_if_stmts.append(stmt) if if_block_with_overload is None: + if if_unknown_truth_value is not None: + self.fail_merge_overload(if_unknown_truth_value) continue if last_if_overload is not None: # Last stmt was an IfStmt with same overload name @@ -542,6 +560,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret.append(last_if_stmt) last_if_stmt_overload_name = current_overload_name last_if_stmt, last_if_overload = None, None + last_if_unknown_truth_value = None if current_overload and current_overload_name == last_if_stmt_overload_name: # Remove last stmt (IfStmt) from ret if the overload names matched @@ -580,6 +599,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: Union[Decorator, FuncDef, OverloadedFuncDef], if_block_with_overload.body[0] ) + last_if_unknown_truth_value = if_unknown_truth_value else: current_overload = [] current_overload_name = None @@ -630,29 +650,40 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: return None - def _get_executable_if_block_with_overloads(self, stmt: IfStmt) -> Optional[Block]: + def _get_executable_if_block_with_overloads( + self, stmt: IfStmt + ) -> Tuple[Optional[Block], Optional[IfStmt]]: """Return block from IfStmt that will get executed. - Only returns block if sure that alternative blocks are unreachable. + Return + 0 -> A block if sure that alternative blocks are unreachable. + 1 -> An IfStmt if the reachability of it can't be inferred, + i.e. the truth value is unknown. """ infer_reachability_of_if_statement(stmt, self.options) + if ( + stmt.else_body is None + and stmt.body[0].is_unreachable is True + ): + # always False condition with no else + return None, None if ( stmt.else_body is None or stmt.body[0].is_unreachable is False and stmt.else_body.is_unreachable is False ): # The truth value is unknown, thus not conclusive - return None + return None, stmt if stmt.else_body.is_unreachable is True: # else_body will be set unreachable if condition is always True - return stmt.body[0] + return stmt.body[0], None if stmt.body[0].is_unreachable is True: # body will be set unreachable if condition is always False # else_body can contain an IfStmt itself (for elif) -> do a recursive check if isinstance(stmt.else_body.body[0], IfStmt): return self._get_executable_if_block_with_overloads(stmt.else_body.body[0]) - return stmt.else_body - return None + return stmt.else_body, None + return None, stmt def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None: """Remove contents from IfStmt. diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 511c51e84342..331a80f7d2ac 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5463,7 +5463,8 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" def f3(g: A) -> A: ... @overload def f3(g: B) -> B: ... -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f3(g: C) -> C: ... def f3(g): ... @@ -5777,7 +5778,8 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... elif True: @@ -5795,7 +5797,8 @@ reveal_type(f3(B())) # E: No overload variant of "f3" matches argument type "B" @overload # E: Single overload definition, multiple required def f4(x: A) -> A: ... -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f4(x: B) -> B: ... else: @@ -5816,6 +5819,7 @@ class A: ... class B: ... class C: ... class D: ... +class E: ... # ----- # Match the first always-true block @@ -5844,7 +5848,8 @@ def f2(x: A) -> A: ... if False: @overload def f2(x: B) -> B: ... -elif maybe_true: # E: Name "maybe_true" is not defined +elif maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f2(x: C) -> C: ... else: @@ -5859,7 +5864,8 @@ reveal_type(f2(C())) # E: No overload variant of "f2" matches argument type "C" @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... elif False: @@ -5875,6 +5881,30 @@ reveal_type(f3(B())) # E: No overload variant of "f3" matches argument type "B" # N: def f3(x: A) -> A \ # N: Revealed type is "Any" +def g(bool_var: bool) -> None: + @overload + def f4(x: A) -> A: ... + if bool_var: # E: Condition cannot be inferred, unable to merge overloads + @overload + def f4(x: B) -> B: ... + elif maybe_true: # E: Name "maybe_true" is not defined + # No 'Condition cannot be inferred' error here since it's already + # emitted on the first condition, 'bool_var', above. + @overload + def f4(x: C) -> C: ... + else: + @overload + def f4(x: D) -> D: ... + @overload + def f4(x: E) -> E: ... + def f4(x): ... + reveal_type(f4(E())) # N: Revealed type is "__main__.E" + reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \ + # N: Possible overload variants: \ + # N: def f4(x: A) -> A \ + # N: def f4(x: E) -> E \ + # N: Revealed type is "Any" + [case testOverloadIfSkipUnknownExecution] # flags: --always-true True from typing import overload @@ -5891,13 +5921,15 @@ class D: ... @overload # E: Single overload definition, multiple required def f1(x: A) -> A: ... -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f1(x: B) -> B: ... def f1(x): ... reveal_type(f1(A())) # N: Revealed type is "__main__.A" -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f2(x: A) -> A: ... @overload @@ -5914,14 +5946,16 @@ reveal_type(f2(A())) # E: No overload variant of "f2" matches argument type "A" if True: @overload # E: Single overload definition, multiple required def f3(x: A) -> A: ... - if maybe_true: # E: Name "maybe_true" is not defined + if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f3(x: B) -> B: ... def f3(x): ... reveal_type(f3(A())) # N: Revealed type is "__main__.A" if True: - if maybe_true: # E: Name "maybe_true" is not defined + if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f4(x: A) -> A: ... @overload @@ -6178,10 +6212,12 @@ reveal_type(f8(C())) # E: No overload variant of "f8" matches argument type "C" # N: def f8(x: B) -> B \ # N: Revealed type is "Any" -if maybe_true: # E: Name "maybe_true" is not defined +if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f9(x: A) -> A: ... -if another_maybe_true: # E: Name "another_maybe_true" is not defined +if another_maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "another_maybe_true" is not defined @overload def f9(x: B) -> B: ... @overload @@ -6197,10 +6233,12 @@ reveal_type(f9(A())) # E: No overload variant of "f9" matches argument type "A" reveal_type(f9(C())) # N: Revealed type is "__main__.C" if True: - if maybe_true: # E: Name "maybe_true" is not defined + if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined @overload def f10(x: A) -> A: ... - if another_maybe_true: # E: Name "another_maybe_true" is not defined + if another_maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "another_maybe_true" is not defined @overload def f10(x: B) -> B: ... @overload From 4f79d43ff5f0de96fb72ebab8a8ad93d16e8940c Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 2 Mar 2022 21:13:52 +0100 Subject: [PATCH 18/20] Add additional test cases --- test-data/unit/check-overloading.test | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 331a80f7d2ac..376ce0e30494 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5430,6 +5430,7 @@ from typing import overload class A: ... class B: ... class C: ... +class D: ... # ----- # Test basic overload merging @@ -5475,6 +5476,35 @@ reveal_type(f3(C())) # E: No overload variant of "f3" matches argument type "C" # N: def f3(g: B) -> B \ # N: Revealed type is "Any" +if True: + @overload + def f4(g: A) -> A: ... +if True: + @overload + def f4(g: B) -> B: ... +@overload +def f4(g: C) -> C: ... +def f4(g): ... +reveal_type(f4(A())) # N: Revealed type is "__main__.A" +reveal_type(f4(B())) # N: Revealed type is "__main__.B" +reveal_type(f4(C())) # N: Revealed type is "__main__.C" + +if True: + @overload + def f5(g: A) -> A: ... +@overload +def f5(g: B) -> B: ... +if True: + @overload + def f5(g: C) -> C: ... +@overload +def f5(g: D) -> D: ... +def f5(g): ... +reveal_type(f5(A())) # N: Revealed type is "__main__.A" +reveal_type(f5(B())) # N: Revealed type is "__main__.B" +reveal_type(f5(C())) # N: Revealed type is "__main__.C" +reveal_type(f5(D())) # N: Revealed type is "__main__.D" + [case testOverloadIfSysVersion] # flags: --python-version 3.9 from typing import overload From 4776070a1d0b211f82ca78a4c6632cd15a436e7c Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 2 Mar 2022 22:23:52 +0100 Subject: [PATCH 19/20] Add documentation --- docs/source/more_types.rst | 108 +++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/docs/source/more_types.rst b/docs/source/more_types.rst index 82a6568afcb2..12bc802e3138 100644 --- a/docs/source/more_types.rst +++ b/docs/source/more_types.rst @@ -581,6 +581,114 @@ with ``Union[int, slice]`` and ``Union[T, Sequence]``. to returning ``Any`` only if the input arguments also contain ``Any``. +Conditional overloads +--------------------- + +Sometimes it's useful to define overloads conditionally. +Common use cases are types which aren't available at runtime or only in +a certain Python version. All existing overload rules still apply. +E.g. if overloads are defined, at least 2 are required. + +.. note:: + + Mypy can only infer a limited number of conditions. + Supported ones currently include :py:data:`~typing.TYPE_CHECKING`, ``MYPY``, + :ref:`version_and_platform_checks`, and :option:`--always-true ` / :option:`--always-false ` values. + It's thus recommended to keep these conditions as simple as possible. + +.. code-block:: python + + from typing import TYPE_CHECKING, Any, overload + + if TYPE_CHECKING: + class A: ... + class B: ... + + + if TYPE_CHECKING: + @overload + def func(var: A) -> A: ... + + @overload + def func(var: B) -> B: ... + + def func(var: Any) -> Any: + return var + + + reveal_type(func(A())) # Revealed type is "A" + +.. code-block:: python + + # flags: --python-version 3.10 + import sys + from typing import Any, overload + + class A: ... + class B: ... + class C: ... + class D: ... + + + if sys.version_info < (3, 7): + @overload + def func(var: A) -> A: ... + + elif sys.version_info >= (3, 10): + @overload + def func(var: B) -> B: ... + + else: + @overload + def func(var: C) -> C: ... + + @overload + def func(var: D) -> D: ... + + def func(var: Any) -> Any: + return var + + + reveal_type(func(B())) # Revealed type is "B" + reveal_type(func(C())) # No overload variant of "func" matches argument type "C" + # Possible overload variants: + # def func(var: B) -> B + # def func(var: D) -> D + # Revealed type is "Any" + + +.. note:: + + In the last example Mypy is executed with + :option:`--python-version 3.10 `. + Because of that the condition ``sys.version_info >= (3, 10)`` will match and + the overload for ``B`` will be added. + The overloads for ``A`` and ``C`` are ignored! + The overload for ``D`` isn't defined conditionally and thus also added. + +In case Mypy can't infer a condition to be always True or always False, an error will be emitted. + +.. code-block:: python + + from typing import Any, overload + + class A: ... + class B: ... + + + def g(bool_var: bool) -> None: + if bool_var: # Condition can't be inferred, unable to merge overloads + @overload + def func(var: A) -> A: ... + + @overload + def func(var: B) -> B: ... + + def func(var: Any) -> Any: ... + + reveal_type(func(A())) # Revealed type is "Any" + + .. _advanced_self: Advanced uses of self-types From cd1e9b0d77b3c56bd86c98d20bf46a8834eb05f6 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 2 Mar 2022 21:08:59 -0800 Subject: [PATCH 20/20] Copyedits to the docs --- docs/source/more_types.rst | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/more_types.rst b/docs/source/more_types.rst index 12bc802e3138..11c5643705ad 100644 --- a/docs/source/more_types.rst +++ b/docs/source/more_types.rst @@ -584,17 +584,17 @@ with ``Union[int, slice]`` and ``Union[T, Sequence]``. Conditional overloads --------------------- -Sometimes it's useful to define overloads conditionally. -Common use cases are types which aren't available at runtime or only in -a certain Python version. All existing overload rules still apply. -E.g. if overloads are defined, at least 2 are required. +Sometimes it is useful to define overloads conditionally. +Common use cases include types that are unavailable at runtime or that +only exist in a certain Python version. All existing overload rules still apply. +For example, there must be at least two overloads. .. note:: Mypy can only infer a limited number of conditions. Supported ones currently include :py:data:`~typing.TYPE_CHECKING`, ``MYPY``, - :ref:`version_and_platform_checks`, and :option:`--always-true ` / :option:`--always-false ` values. - It's thus recommended to keep these conditions as simple as possible. + :ref:`version_and_platform_checks`, and :option:`--always-true ` + and :option:`--always-false ` values. .. code-block:: python @@ -659,14 +659,14 @@ E.g. if overloads are defined, at least 2 are required. .. note:: - In the last example Mypy is executed with + In the last example, mypy is executed with :option:`--python-version 3.10 `. - Because of that the condition ``sys.version_info >= (3, 10)`` will match and + Therefore, the condition ``sys.version_info >= (3, 10)`` will match and the overload for ``B`` will be added. The overloads for ``A`` and ``C`` are ignored! - The overload for ``D`` isn't defined conditionally and thus also added. + The overload for ``D`` is not defined conditionally and thus is also added. -In case Mypy can't infer a condition to be always True or always False, an error will be emitted. +When mypy cannot infer a condition to be always True or always False, an error is emitted. .. code-block:: python