From a480e644cc2f4a943ed372281e52f345cc32250e Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Sun, 3 Nov 2019 23:56:32 +0100
Subject: [PATCH 1/7] custom type narrower

---
 mypy/checker.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/mypy/checker.py b/mypy/checker.py
index 7cc1b04b5d91..d8640799e8f4 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -162,7 +162,12 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
     binder = None  # type: ConditionalTypeBinder
     # Helper for type checking expressions
     expr_checker = None  # type: mypy.checkexpr.ExpressionChecker
-
+    # temporary container for ctn
+    ctns_queue = None  # type:List[str]
+    # uniqueness check for ctn
+    ctns_keys = None # type:Set[str]
+    # custom type narrowers
+    ctns = None  # type:List[Tuple[str, Expression]]
     tscope = None  # type: Scope
     scope = None  # type: CheckerScope
     # Stack of function return types
@@ -231,6 +236,9 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
         self.type_map = {}
         self.module_refs = set()
         self.pass_num = 0
+        self.ctns_queue = []
+        self.ctns_keys = set()
+        self.ctns = []
         self.current_node_deferred = False
         self.is_stub = tree.is_stub
         self.is_typeshed_stub = errors.is_typeshed_file(path)
@@ -3379,6 +3387,11 @@ def visit_decorator(self, e: Decorator) -> None:
                     e.var.type = AnyType(TypeOfAny.special_form)
                     e.var.is_ready = True
                     return
+            elif isinstance(d, CallExpr):
+                assert isinstance(d.callee, RefExpr)
+                if d.callee.fullname == 'mypy.extern.narrow_cast':
+                    # this function is a CTN
+                    self.ctns_queue.append(e.func.fullname())  # type will be added later
 
         if self.recurse_into_functions:
             with self.tscope.function_scope(e.func):
@@ -3702,6 +3715,16 @@ def find_isinstance_check(self, node: Expression
         elif is_false_literal(node):
             return None, {}
         elif isinstance(node, CallExpr):
+            expr = None
+            vartype = None
+            type = None
+            for name, expr_ in self.ctns:
+                if refers_to_fullname(node.callee, name):
+                    expr = node.args[0]
+                    if literal(expr_) == LITERAL_TYPE:
+                        vartype = type_map[node.args[0]]
+                        type = get_isinstance_type(expr_, type_map)
+                        break  # name is unique
             if refers_to_fullname(node.callee, 'builtins.isinstance'):
                 if len(node.args) != 2:  # the error will be reported elsewhere
                     return {}, {}
@@ -3709,7 +3732,8 @@ def find_isinstance_check(self, node: Expression
                 if literal(expr) == LITERAL_TYPE:
                     vartype = type_map[expr]
                     type = get_isinstance_type(node.args[1], type_map)
-                    return conditional_type_map(expr, vartype, type)
+            if expr and vartype and type:
+                return conditional_type_map(expr, vartype, type)
             elif refers_to_fullname(node.callee, 'builtins.issubclass'):
                 if len(node.args) != 2:  # the error will be reported elsewhere
                     return {}, {}

From 4c66b10e5c74bb4212259d2f87d724fbf9131630 Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 01:40:59 +0100
Subject: [PATCH 2/7] custom narrower

---
 mypy/checkexpr.py | 14 +++++++++++---
 mypy/extern.py    | 10 ++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)
 create mode 100644 mypy/extern.py

diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py
index 27f462b7d00f..f8abfcc1e515 100644
--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -270,9 +270,12 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
             typeddict_type = e.callee.node.typeddict_type.copy_modified(
                 fallback=Instance(e.callee.node, []))
             return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
-        if (isinstance(e.callee, NameExpr) and e.callee.name in ('isinstance', 'issubclass')
-                and len(e.args) == 2):
-            for typ in mypy.checker.flatten(e.args[1]):
+        if (isinstance(e.callee, NameExpr)
+                and ((e.callee.name in ('isinstance', 'issubclass') and len(e.args) == 2)
+                or e.callee.name == 'narrow_cast' and len(e.args) == 1)):
+            is_narrow_cast = e.callee.name == 'narrow_cast'
+            arg = e.args[1] if not is_narrow_cast else e.args[0]
+            for typ in mypy.checker.flatten(arg):
                 node = None
                 if isinstance(typ, NameExpr):
                     try:
@@ -297,6 +300,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
                         self.msg.cannot_use_function_with_type(e.callee.name, "TypedDict", e)
                     elif typ.node.is_newtype:
                         self.msg.cannot_use_function_with_type(e.callee.name, "NewType", e)
+            if is_narrow_cast:
+                ctn_name = self.chk.ctns_queue.pop()
+                if ctn_name not in self.chk.ctns_keys:
+                    self.chk.ctns.append((ctn_name, arg))
+                    self.chk.ctns_keys.add(ctn_name)
         self.try_infer_partial_type(e)
         type_context = None
         if isinstance(e.callee, LambdaExpr):
diff --git a/mypy/extern.py b/mypy/extern.py
new file mode 100644
index 000000000000..973c42a9fd23
--- /dev/null
+++ b/mypy/extern.py
@@ -0,0 +1,10 @@
+from typing import Union, Tuple, Any, Callable
+
+
+# copy pasted from typeshed
+def narrow_cast(T: Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]) \
+        -> Callable[..., Callable[..., bool]]:
+    def narrow_cast_inner(f: Callable[..., bool]) -> Callable[..., bool]:
+        return f  # binds first argument of f to T
+
+    return narrow_cast_inner

From 1e4e15750919ab4a6846cac62a56c2aa43e8329a Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 02:01:01 +0100
Subject: [PATCH 3/7] typo

---
 mypy/checker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mypy/checker.py b/mypy/checker.py
index d8640799e8f4..ffb76b6fc424 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -3721,7 +3721,7 @@ def find_isinstance_check(self, node: Expression
             for name, expr_ in self.ctns:
                 if refers_to_fullname(node.callee, name):
                     expr = node.args[0]
-                    if literal(expr_) == LITERAL_TYPE:
+                    if literal(expr) == LITERAL_TYPE:
                         vartype = type_map[node.args[0]]
                         type = get_isinstance_type(expr_, type_map)
                         break  # name is unique

From 28e3ea7488272423182b2a0e1275a7811d94cbe3 Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 20:43:22 +0100
Subject: [PATCH 4/7] add tests

---
 test-data/unit/check-isinstance.test | 49 ++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test
index 2b62699b5166..ac7161dca94a 100644
--- a/test-data/unit/check-isinstance.test
+++ b/test-data/unit/check-isinstance.test
@@ -2279,3 +2279,52 @@ var = 'some string'
 if isinstance(var, *(str, int)):  # E: Too many arguments for "isinstance"
     pass
 [builtins fixtures/isinstancelist.pyi]
+
+[case testCustomTypeNarrower]
+from typing import Union, List, Tuple, Any, Callable
+from mypy.extern import narrow_cast
+@narrow_cast(int)
+def isint(x: Union[str, int], *args) -> bool:
+    return isinstance(x, int)
+
+@narrow_cast(str)
+def isstr(x, name) -> bool:
+    return name == 'STRING'
+
+u: Union[str, int] = 5
+x: Union[str, float]
+if isint(u):
+    reveal_type(u)  # N: Revealed type is 'builtins.int'
+if isinstance(u, int):
+    reveal_type(u)  # N: Revealed type is 'builtins.int'
+if isstr(x, 'STRING'):
+    x + ""
+
+@narrow_cast(str)
+def is_fizz_buzz(foo):
+    return foo in ['fizz', 'buzz']
+
+def foobar(foo: Union[str, float]):
+    if foo in ['fizz', 'buzz']:
+        reveal_type(foo)  # N: Revealed type is 'Union[builtins.str, builtins.float]'
+    if is_fizz_buzz(foo):
+        reveal_type(foo)  # N: Revealed type is 'builtins.str'
+
+@narrow_cast((str, (int,)))
+def is_str_or_int(x):
+    return isinstance(x, (str, (int,)))
+
+@narrow_cast((str, (list,)))
+def is_str_or_list(x):
+    return isinstance(x, (str, (list,)))
+
+def f(x: Union[int, str, List]) -> None:
+    if isinstance(x, (str, (int,))):
+        reveal_type(x)  # N: Revealed type is 'Union[builtins.int, builtins.str]'
+    if is_str_or_int(x):
+        reveal_type(x)  # N: Revealed type is 'Union[builtins.int, builtins.str]'
+    if isinstance(x, (str, (list,))):
+        reveal_type(x)  # N: Revealed type is 'Union[builtins.str, builtins.list[Any]]'
+    if is_str_or_list(x):
+        reveal_type(x)   # N: Revealed type is 'Union[builtins.str, builtins.list[Any]]'
+[builtins fixtures/isinstancelist.pyi]

From 6ca754216a1a38db208e1a48d6a42c9f40e7ac0c Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 20:44:51 +0100
Subject: [PATCH 5/7] import && stub

---
 mypy/checkexpr.py                        |  7 ++++---
 test-data/unit/lib-stub/mypy/__init__.py |  0
 test-data/unit/lib-stub/mypy/extern.py   | 10 ++++++++++
 3 files changed, 14 insertions(+), 3 deletions(-)
 create mode 100644 test-data/unit/lib-stub/mypy/__init__.py
 create mode 100644 test-data/unit/lib-stub/mypy/extern.py

diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py
index f8abfcc1e515..096e147a3fb6 100644
--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -61,6 +61,7 @@
     tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
     function_type, callable_type, try_getting_str_literals
 )
+from mypy.semanal import refers_to_fullname
 import mypy.errorcodes as codes
 
 # Type of callback user for checking individual function arguments. See
@@ -272,8 +273,8 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
             return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
         if (isinstance(e.callee, NameExpr)
                 and ((e.callee.name in ('isinstance', 'issubclass') and len(e.args) == 2)
-                or e.callee.name == 'narrow_cast' and len(e.args) == 1)):
-            is_narrow_cast = e.callee.name == 'narrow_cast'
+                or refers_to_fullname(e.callee, 'mypy.extern.narrow_cast') and len(e.args) == 1)):
+            is_narrow_cast = refers_to_fullname(e.callee, 'mypy.extern.narrow_cast')
             arg = e.args[1] if not is_narrow_cast else e.args[0]
             for typ in mypy.checker.flatten(arg):
                 node = None
@@ -3316,7 +3317,7 @@ def visit_super_expr(self, e: SuperExpr) -> Type:
             self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e)
             return AnyType(TypeOfAny.from_error)
 
-        for base in mro[index+1:]:
+        for base in mro[index + 1:]:
             if e.name in base.names or base == mro[-1]:
                 if e.info and e.info.fallback_to_any and base == mro[-1]:
                     # There's an undefined base class, and we're at the end of the
diff --git a/test-data/unit/lib-stub/mypy/__init__.py b/test-data/unit/lib-stub/mypy/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/test-data/unit/lib-stub/mypy/extern.py b/test-data/unit/lib-stub/mypy/extern.py
new file mode 100644
index 000000000000..973c42a9fd23
--- /dev/null
+++ b/test-data/unit/lib-stub/mypy/extern.py
@@ -0,0 +1,10 @@
+from typing import Union, Tuple, Any, Callable
+
+
+# copy pasted from typeshed
+def narrow_cast(T: Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]) \
+        -> Callable[..., Callable[..., bool]]:
+    def narrow_cast_inner(f: Callable[..., bool]) -> Callable[..., bool]:
+        return f  # binds first argument of f to T
+
+    return narrow_cast_inner

From ee48fb436c60702e802294e6ab03d68e1137e71d Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 20:48:32 +0100
Subject: [PATCH 6/7] comment

---
 mypy/extern.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mypy/extern.py b/mypy/extern.py
index 973c42a9fd23..d796dd15a1d6 100644
--- a/mypy/extern.py
+++ b/mypy/extern.py
@@ -1,7 +1,7 @@
 from typing import Union, Tuple, Any, Callable
 
 
-# copy pasted from typeshed
+# signature is partially copy pasted from typeshed isinstance
 def narrow_cast(T: Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]) \
         -> Callable[..., Callable[..., bool]]:
     def narrow_cast_inner(f: Callable[..., bool]) -> Callable[..., bool]:

From e4574eee1a44643d7917f4c3f48705b5759a252d Mon Sep 17 00:00:00 2001
From: Denis Gantsev <gantsevdenis@gmail.com>
Date: Mon, 4 Nov 2019 21:31:38 +0100
Subject: [PATCH 7/7] flake8

---
 mypy/checker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mypy/checker.py b/mypy/checker.py
index ffb76b6fc424..63ae070a6eb0 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -165,9 +165,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
     # temporary container for ctn
     ctns_queue = None  # type:List[str]
     # uniqueness check for ctn
-    ctns_keys = None # type:Set[str]
+    ctns_keys = None  # type: Set[str]
     # custom type narrowers
-    ctns = None  # type:List[Tuple[str, Expression]]
+    ctns = None  # type: List[Tuple[str, Expression]]
     tscope = None  # type: Scope
     scope = None  # type: CheckerScope
     # Stack of function return types