Skip to content

Commit 1fb67d0

Browse files
authored
Add a new, lower priority for imports inside "if MYPY" (#2167)
This improves the processing order of modules involved in a cycle when one arc of the cycle only exists inside `if MYPY` or `if typing.TYPE_CHECKING`. Unittest by @elazarg.
1 parent 6998787 commit 1fb67d0

File tree

4 files changed

+95
-25
lines changed

4 files changed

+95
-25
lines changed

mypy/build.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import (AbstractSet, Dict, Iterable, Iterator, List,
2424
NamedTuple, Optional, Set, Tuple, Union)
2525

26-
from mypy.nodes import (MypyFile, Import, ImportFrom, ImportAll)
26+
from mypy.nodes import (MypyFile, Node, ImportBase, Import, ImportFrom, ImportAll)
2727
from mypy.semanal import FirstPass, SemanticAnalyzer, ThirdPass
2828
from mypy.checker import TypeChecker
2929
from mypy.indirection import TypeIndirectionVisitor
@@ -307,10 +307,23 @@ def default_lib_path(data_dir: str,
307307
PRI_HIGH = 5 # top-level "from X import blah"
308308
PRI_MED = 10 # top-level "import X"
309309
PRI_LOW = 20 # either form inside a function
310+
PRI_MYPY = 25 # inside "if MYPY" or "if TYPE_CHECKING"
310311
PRI_INDIRECT = 30 # an indirect dependency
311312
PRI_ALL = 99 # include all priorities
312313

313314

315+
def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
316+
"""Compute import priority from an import node."""
317+
if not imp.is_top_level:
318+
# Inside a function
319+
return PRI_LOW
320+
if imp.is_mypy_only:
321+
# Inside "if MYPY" or "if typing.TYPE_CHECKING"
322+
return max(PRI_MYPY, toplevel_priority)
323+
# A regular import; priority determined by argument.
324+
return toplevel_priority
325+
326+
314327
# TODO: Get rid of all_types. It's not used except for one log message.
315328
# Maybe we could instead publish a map from module ID to its type_map.
316329
class BuildManager:
@@ -396,20 +409,21 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
396409
for imp in file.imports:
397410
if not imp.is_unreachable:
398411
if isinstance(imp, Import):
399-
pri = PRI_MED if imp.is_top_level else PRI_LOW
412+
pri = import_priority(imp, PRI_MED)
413+
ancestor_pri = import_priority(imp, PRI_LOW)
400414
for id, _ in imp.ids:
401415
ancestor_parts = id.split(".")[:-1]
402416
ancestors = []
403417
for part in ancestor_parts:
404418
ancestors.append(part)
405-
res.append((PRI_LOW, ".".join(ancestors), imp.line))
419+
res.append((ancestor_pri, ".".join(ancestors), imp.line))
406420
res.append((pri, id, imp.line))
407421
elif isinstance(imp, ImportFrom):
408422
cur_id = correct_rel_imp(imp)
409423
pos = len(res)
410424
all_are_submodules = True
411425
# Also add any imported names that are submodules.
412-
pri = PRI_MED if imp.is_top_level else PRI_LOW
426+
pri = import_priority(imp, PRI_MED)
413427
for name, __ in imp.names:
414428
sub_id = cur_id + '.' + name
415429
if self.is_module(sub_id):
@@ -422,10 +436,10 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
422436
# cur_id is also a dependency, and we should
423437
# insert it *before* any submodules.
424438
if not all_are_submodules:
425-
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
439+
pri = import_priority(imp, PRI_HIGH)
426440
res.insert(pos, ((pri, cur_id, imp.line)))
427441
elif isinstance(imp, ImportAll):
428-
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
442+
pri = import_priority(imp, PRI_HIGH)
429443
res.append((pri, correct_rel_imp(imp), imp.line))
430444

431445
return res
@@ -1704,8 +1718,8 @@ def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_ALL) ->
17041718
each SCC thus found. The recursion is bounded because at each
17051719
recursion the spread in priorities is (at least) one less.
17061720
1707-
In practice there are only a few priority levels (currently
1708-
N=3) and in the worst case we just carry out the same algorithm
1721+
In practice there are only a few priority levels (less than a
1722+
dozen) and in the worst case we just carry out the same algorithm
17091723
for finding SCCs N times. Thus the complexity is no worse than
17101724
the complexity of the original SCC-finding algorithm -- see
17111725
strongly_connected_components() below for a reference.

mypy/nodes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,11 @@ def deserialize(cls, data: JsonDict) -> 'MypyFile':
284284

285285
class ImportBase(Statement):
286286
"""Base class for all import statements."""
287-
is_unreachable = False
288-
is_top_level = False # Set by semanal.FirstPass
287+
288+
is_unreachable = False # Set by semanal.FirstPass if inside `if False` etc.
289+
is_top_level = False # Ditto if outside any class or def
290+
is_mypy_only = False # Ditto if inside `if TYPE_CHECKING` or `if MYPY`
291+
289292
# If an import replaces existing definitions, we construct dummy assignment
290293
# statements that assign the imported names to the names in the current scope,
291294
# for type checking purposes. Example:

mypy/semanal.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,20 @@
8484
T = TypeVar('T')
8585

8686

87-
# Inferred value of an expression.
88-
ALWAYS_TRUE = 0
89-
ALWAYS_FALSE = 1
90-
TRUTH_VALUE_UNKNOWN = 2
87+
# Inferred truth value of an expression.
88+
ALWAYS_TRUE = 1
89+
MYPY_TRUE = 2 # True in mypy, False at runtime
90+
ALWAYS_FALSE = 3
91+
MYPY_FALSE = 4 # False in mypy, True at runtime
92+
TRUTH_VALUE_UNKNOWN = 5
93+
94+
inverted_truth_mapping = {
95+
ALWAYS_TRUE: ALWAYS_FALSE,
96+
ALWAYS_FALSE: ALWAYS_TRUE,
97+
TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
98+
MYPY_TRUE: MYPY_FALSE,
99+
MYPY_FALSE: MYPY_TRUE,
100+
}
91101

92102
# Map from obsolete name to the current spelling.
93103
obsolete_name_mapping = {
@@ -3124,12 +3134,16 @@ def infer_reachability_of_if_statement(s: IfStmt,
31243134
platform: str) -> None:
31253135
for i in range(len(s.expr)):
31263136
result = infer_if_condition_value(s.expr[i], pyversion, platform)
3127-
if result == ALWAYS_FALSE:
3128-
# The condition is always false, so we skip the if/elif body.
3137+
if result in (ALWAYS_FALSE, MYPY_FALSE):
3138+
# The condition is considered always false, so we skip the if/elif body.
31293139
mark_block_unreachable(s.body[i])
3130-
elif result == ALWAYS_TRUE:
3131-
# This condition is always true, so all of the remaining
3132-
# elif/else bodies will never be executed.
3140+
elif result in (ALWAYS_TRUE, MYPY_TRUE):
3141+
# This condition is considered always true, so all of the remaining
3142+
# elif/else bodies should not be checked.
3143+
if result == MYPY_TRUE:
3144+
# This condition is false at runtime; this will affect
3145+
# import priorities.
3146+
mark_block_mypy_only(s.body[i])
31333147
for body in s.body[i + 1:]:
31343148
mark_block_unreachable(body)
31353149
if s.else_body:
@@ -3141,7 +3155,8 @@ def infer_if_condition_value(expr: Expression, pyversion: Tuple[int, int], platf
31413155
"""Infer whether if condition is always true/false.
31423156
31433157
Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
3144-
and TRUTH_VALUE_UNKNOWN otherwise.
3158+
MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
3159+
false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.
31453160
"""
31463161
name = ''
31473162
negated = False
@@ -3165,12 +3180,9 @@ def infer_if_condition_value(expr: Expression, pyversion: Tuple[int, int], platf
31653180
elif name == 'PY3':
31663181
result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
31673182
elif name == 'MYPY' or name == 'TYPE_CHECKING':
3168-
result = ALWAYS_TRUE
3183+
result = MYPY_TRUE
31693184
if negated:
3170-
if result == ALWAYS_TRUE:
3171-
result = ALWAYS_FALSE
3172-
elif result == ALWAYS_FALSE:
3173-
result = ALWAYS_TRUE
3185+
result = inverted_truth_mapping[result]
31743186
return result
31753187

31763188

@@ -3345,6 +3357,23 @@ def visit_import_all(self, node: ImportAll) -> None:
33453357
node.is_unreachable = True
33463358

33473359

3360+
def mark_block_mypy_only(block: Block) -> None:
3361+
block.accept(MarkImportsMypyOnlyVisitor())
3362+
3363+
3364+
class MarkImportsMypyOnlyVisitor(TraverserVisitor):
3365+
"""Visitor that sets is_mypy_only (which affects priority)."""
3366+
3367+
def visit_import(self, node: Import) -> None:
3368+
node.is_mypy_only = True
3369+
3370+
def visit_import_from(self, node: ImportFrom) -> None:
3371+
node.is_mypy_only = True
3372+
3373+
def visit_import_all(self, node: ImportAll) -> None:
3374+
node.is_mypy_only = True
3375+
3376+
33483377
def is_identity_signature(sig: Type) -> bool:
33493378
"""Is type a callable of form T -> T (where T is a type variable)?"""
33503379
if isinstance(sig, CallableType) and sig.arg_kinds == [ARG_POS]:

test-data/unit/check-modules.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,27 @@ pass
13131313
[file b]
13141314
pass
13151315
[out]
1316+
1317+
[case testTypeCheckPrio]
1318+
# cmd: mypy -m part1 part2 part3 part4
1319+
1320+
[file part1.py]
1321+
from part3 import Thing
1322+
class FirstThing: pass
1323+
1324+
[file part2.py]
1325+
from part4 import part4_thing as Thing
1326+
1327+
[file part3.py]
1328+
from part2 import Thing
1329+
reveal_type(Thing)
1330+
1331+
[file part4.py]
1332+
from typing import TYPE_CHECKING
1333+
if TYPE_CHECKING:
1334+
from part1 import FirstThing
1335+
def part4_thing(a: int) -> str: pass
1336+
1337+
[builtins fixtures/bool.pyi]
1338+
[out]
1339+
tmp/part3.py:2: error: Revealed type is 'def (a: builtins.int) -> builtins.str'

0 commit comments

Comments
 (0)