diff --git a/mypy/build.py b/mypy/build.py index 1d515ac9caeb..ab1ddc984a50 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -16,7 +16,7 @@ import sys from os.path import dirname, basename -from typing import Undefined, Dict, List, Tuple, cast, Set +from typing import Undefined, Dict, List, Tuple, cast, Set, Union from mypy.types import Type from mypy.nodes import MypyFile, Node, Import, ImportFrom, ImportAll @@ -451,6 +451,20 @@ def all_imported_modules_in_file(self, Return list of tuples (module id, import line number) for all modules imported in file. """ + def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str: + """Function to correct for relative imports.""" + file_id = file.fullname() + rel = imp.relative + if rel == 0: + return imp.id + if os.path.basename(file.path) == '__init__.py': + rel -= 1 + if rel != 0: + file_id = ".".join(file_id.split(".")[:-rel]) + new_id = file_id + "." + imp.id if imp.id else file_id + + return new_id + res = List[Tuple[str, int]]() for imp in file.imports: if not imp.is_unreachable: @@ -458,14 +472,15 @@ def all_imported_modules_in_file(self, for id, _ in imp.ids: res.append((id, imp.line)) elif isinstance(imp, ImportFrom): - res.append((imp.id, imp.line)) + cur_id = correct_rel_imp(imp) + res.append((cur_id, imp.line)) # Also add any imported names that are submodules. for name, __ in imp.names: - sub_id = imp.id + '.' + name + sub_id = cur_id + '.' + name if self.is_module(sub_id): res.append((sub_id, imp.line)) elif isinstance(imp, ImportAll): - res.append((imp.id, imp.line)) + res.append((correct_rel_imp(imp), imp.line)) return res def is_module(self, id: str) -> bool: diff --git a/mypy/noderepr.py b/mypy/noderepr.py index 7bdf89befe30..4da110fce91b 100644 --- a/mypy/noderepr.py +++ b/mypy/noderepr.py @@ -34,16 +34,18 @@ def __init__(self, import_tok: Any, components: List[List[Token]], class ImportFromRepr: def __init__(self, from_tok: Any, + rel_toks: Any, components: List[Token], import_tok: Any, lparen: Any, names: List[Tuple[List[Token], Token]], rparen: Any, br: Any) -> None: # Notes: - # - lparen and rparen may be empty + # - lparen, rparen, and rel_tok may be empty # - in each names tuple, the first item contains tokens for # 'name [as name]' and the second item is a comma or empty. self.from_tok = from_tok + self.rel_toks = rel_toks self.components = components self.import_tok = import_tok self.lparen = lparen diff --git a/mypy/nodes.py b/mypy/nodes.py index 78acee6002c3..996caae28ac8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1,5 +1,6 @@ """Abstract syntax tree node classes (i.e. parse tree).""" +import os import re from abc import abstractmethod, ABCMeta @@ -140,6 +141,9 @@ def fullname(self) -> str: def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_mypy_file(self) + def is_package_init_file(self) -> bool: + return not (self.path is None) and len(self.path) != 0 \ + and os.path.basename(self.path) == '__init__.py' class ImportBase(Node): """Base class for all import statements.""" @@ -163,10 +167,11 @@ class ImportFrom(ImportBase): names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name) - def __init__(self, id: str, names: List[Tuple[str, str]]) -> None: + def __init__(self, id: str, relative: int, names: List[Tuple[str, str]]) -> None: self.id = id self.names = names - + self.relative = relative + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_from(self) @@ -174,9 +179,10 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class ImportAll(ImportBase): """from m import *""" - def __init__(self, id: str) -> None: + def __init__(self, id: str, relative: int) -> None: self.id = id - + self.relative = relative + def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_import_all(self) diff --git a/mypy/output.py b/mypy/output.py index 7352cd038227..7c37558ee8fe 100644 --- a/mypy/output.py +++ b/mypy/output.py @@ -55,6 +55,7 @@ def visit_import_all(self, o): def output_import_from_or_all(self, o): r = o.repr self.token(r.from_tok) + self.tokens(r.rel_toks) self.tokens(r.components) self.token(r.import_tok) self.token(r.lparen) diff --git a/mypy/parse.py b/mypy/parse.py index dd1bdbdca8fd..62e6759db6d3 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -170,17 +170,35 @@ def parse_import(self) -> Import: def parse_import_from(self) -> Node: from_tok = self.expect('from') - name, components = self.parse_qualified_name() + + # Build the list of beginning relative tokens. + relative = 0 + rel_toks = List[Token]() + while self.current_str() == ".": + rel_toks.append(self.expect('.')) + relative += 1 + + # Parse qualified name to actually import from. + if self.current_str() == "import": + # Empty/defualt values. + name = "" + components = List[Token]() + else: + name, components = self.parse_qualified_name() + if name == self.custom_typing_module: name = 'typing' + + # Parse import list import_tok = self.expect('import') name_toks = List[Tuple[List[Token], Token]]() lparen = none rparen = none node = None # type: ImportBase if self.current_str() == '*': + # An import all from a module node: name_toks.append(([self.skip()], none)) - node = ImportAll(name) + node = ImportAll(name, relative) else: is_paren = self.current_str() == '(' if is_paren: @@ -206,12 +224,12 @@ def parse_import_from(self) -> Node: if is_paren: rparen = self.expect(')') if node is None: - node = ImportFrom(name, targets) + node = ImportFrom(name, relative, targets) br = self.expect_break() self.imports.append(node) # TODO: Fix representation if there is a custom typing module import. self.set_repr(node, noderepr.ImportFromRepr( - from_tok, components, import_tok, lparen, name_toks, rparen, br)) + from_tok, rel_toks, components, import_tok, lparen, name_toks, rparen, br)) if name == '__future__': self.future_options.extend(target[0] for target in targets) return node diff --git a/mypy/semanal.py b/mypy/semanal.py index e4f19fca69f0..8df8f3aab646 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -143,8 +143,9 @@ def __init__(self, lib_path: List[str], errors: Errors, def visit_file(self, file_node: MypyFile, fnam: str) -> None: self.errors.set_file(fnam) - self.globals = file_node.names + self.cur_mod_node = file_node self.cur_mod_id = file_node.fullname() + self.globals = file_node.names if 'builtins' in self.modules: self.globals['__builtins__'] = SymbolTableNode( @@ -605,8 +606,9 @@ def add_module_symbol(self, id: str, as_id: str, context: Context) -> None: self.add_unknown_symbol(as_id, context) def visit_import_from(self, i: ImportFrom) -> None: - if i.id in self.modules: - m = self.modules[i.id] + i_id = self.correct_relative_import(i) + if i_id in self.modules: + m = self.modules[i_id] for id, as_id in i.names: node = m.names.get(id, None) if node: @@ -628,9 +630,27 @@ def normalize_type_alias(self, node: SymbolTableNode, node = self.lookup_qualified(type_aliases[node.fullname], ctx) return node + def correct_relative_import(self, node: Union[ImportFrom, ImportAll]) -> str: + if node.relative == 0: + return node.id + + parts = self.cur_mod_id.split(".") + cur_mod_id = self.cur_mod_id + + rel = node.relative + if self.cur_mod_node.is_package_init_file(): + rel -= 1 + if len(parts) < rel: + self.fail("Relative import climbs too many namespaces.", node) + if rel != 0: + cur_mod_id = ".".join(parts[:-rel]) + + return cur_mod_id + (("." + node.id) if node.id else "") + def visit_import_all(self, i: ImportAll) -> None: - if i.id in self.modules: - m = self.modules[i.id] + i_id = self.correct_relative_import(i) + if i_id in self.modules: + m = self.modules[i_id] for name, node in m.names.items(): node = self.normalize_type_alias(node, i) if not name.startswith('_'): diff --git a/mypy/strconv.py b/mypy/strconv.py index 78da8396bfb4..d64e0f72ddd1 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -92,10 +92,10 @@ def visit_import_from(self, o): a = [] for name, as_name in o.names: a.append('{} : {}'.format(name, as_name)) - return 'ImportFrom:{}({}, [{}])'.format(o.line, o.id, ', '.join(a)) + return 'ImportFrom:{}({}, [{}])'.format(o.line, "." * o.relative + o.id, ', '.join(a)) def visit_import_all(self, o): - return 'ImportAll:{}({})'.format(o.line, o.id) + return 'ImportAll:{}({})'.format(o.line, "." * o.relative + o.id) # Definitions diff --git a/mypy/test/data/semanal-modules.test b/mypy/test/data/semanal-modules.test index 284f4e5f501c..f89197e28830 100644 --- a/mypy/test/data/semanal-modules.test +++ b/mypy/test/data/semanal-modules.test @@ -632,3 +632,112 @@ MypyFile:1( AssignmentStmt:1( NameExpr(y* [x.y]) IntExpr(1))) + +[case testRelativeImport0] +import m.x +m.x.z.y +[file m/__init__.py] +[file m/x.py] +from . import z +[file m/z.py] +y = 1 +[out] +MypyFile:1( + Import:1(m.x : m.x) + ExpressionStmt:2( + MemberExpr:2( + MemberExpr:2( + MemberExpr:2( + NameExpr(m) + x [m.x]) + z [m.z]) + y [m.z.y]))) +MypyFile:1( + tmp/m/x.py + ImportFrom:1(., [z : z])) +MypyFile:1( + tmp/m/z.py + AssignmentStmt:1( + NameExpr(y* [m.z.y]) + IntExpr(1))) + +[case testRelativeImport1] +import m.t.b as b +b.x.y +b.z.y +[file m/__init__.py] +[file m/x.py] +y = 1 +[file m/z.py] +y = 3 +[file m/t/__init__.py] +[file m/t/b.py] +from .. import x, z +[out] +MypyFile:1( + Import:1(m.t.b : b) + ExpressionStmt:2( + MemberExpr:2( + MemberExpr:2( + NameExpr(b [m.t.b]) + x [m.x]) + y [m.x.y])) + ExpressionStmt:3( + MemberExpr:3( + MemberExpr:3( + NameExpr(b [m.t.b]) + z [m.z]) + y [m.z.y]))) +MypyFile:1( + tmp/m/t/b.py + ImportFrom:1(.., [x : x, z : z])) +MypyFile:1( + tmp/m/x.py + AssignmentStmt:1( + NameExpr(y* [m.x.y]) + IntExpr(1))) +MypyFile:1( + tmp/m/z.py + AssignmentStmt:1( + NameExpr(y* [m.z.y]) + IntExpr(3))) + +[case testRelativeImport2] +import m.t.b as b +b.xy +b.zy +[file m/__init__.py] +[file m/x.py] +y = 1 +[file m/z.py] +y = 3 +[file m/t/__init__.py] +[file m/t/b.py] +from ..x import y as xy +from ..z import y as zy +[out] +MypyFile:1( + Import:1(m.t.b : b) + ExpressionStmt:2( + MemberExpr:2( + NameExpr(b [m.t.b]) + xy [m.x.y])) + ExpressionStmt:3( + MemberExpr:3( + NameExpr(b [m.t.b]) + zy [m.z.y]))) +MypyFile:1( + tmp/m/t/b.py + ImportFrom:1(..x, [y : xy]) + ImportFrom:2(..z, [y : zy])) +MypyFile:1( + tmp/m/x.py + AssignmentStmt:1( + NameExpr(y* [m.x.y]) + IntExpr(1))) +MypyFile:1( + tmp/m/z.py + AssignmentStmt:1( + NameExpr(y* [m.z.y]) + IntExpr(3))) + \ No newline at end of file diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 9c67e8cc7593..43c3dcda24f5 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -60,10 +60,10 @@ def visit_import(self, node: Import) -> Node: return Import(node.ids[:]) def visit_import_from(self, node: ImportFrom) -> Node: - return ImportFrom(node.id, node.names[:]) + return ImportFrom(node.id, node.relative, node.names[:]) def visit_import_all(self, node: ImportAll) -> Node: - return ImportAll(node.id) + return ImportAll(node.id, node.relative) def visit_func_def(self, node: FuncDef) -> FuncDef: # Note that a FuncDef must be transformed to a FuncDef.