Skip to content

Adding relative import support (#60) #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 7, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys
from os.path import dirname, basename

from typing import Undefined, Dict, List, Tuple, cast, Set
from typing import Undefined, Dict, List, Tuple, cast, Set, Union

from mypy.types import Type
from mypy.nodes import MypyFile, Node, Import, ImportFrom, ImportAll
Expand Down Expand Up @@ -451,21 +451,36 @@ def all_imported_modules_in_file(self,
Return list of tuples (module id, import line number) for all modules
imported in file.
"""
def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
"""Function to correct for relative imports."""
file_id = file.fullname()
rel = imp.relative
if rel == 0:
return imp.id
if os.path.basename(file.path) == '__init__.py':
rel -= 1
if rel != 0:
file_id = ".".join(file_id.split(".")[:-rel])
new_id = file_id + "." + imp.id if imp.id else file_id

return new_id

res = List[Tuple[str, int]]()
for imp in file.imports:
if not imp.is_unreachable:
if isinstance(imp, Import):
for id, _ in imp.ids:
res.append((id, imp.line))
elif isinstance(imp, ImportFrom):
res.append((imp.id, imp.line))
cur_id = correct_rel_imp(imp)
res.append((cur_id, imp.line))
# Also add any imported names that are submodules.
for name, __ in imp.names:
sub_id = imp.id + '.' + name
sub_id = cur_id + '.' + name
if self.is_module(sub_id):
res.append((sub_id, imp.line))
elif isinstance(imp, ImportAll):
res.append((imp.id, imp.line))
res.append((correct_rel_imp(imp), imp.line))
return res

def is_module(self, id: str) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion mypy/noderepr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,18 @@ def __init__(self, import_tok: Any, components: List[List[Token]],
class ImportFromRepr:
def __init__(self,
from_tok: Any,
rel_toks: Any,
components: List[Token],
import_tok: Any,
lparen: Any,
names: List[Tuple[List[Token], Token]],
rparen: Any, br: Any) -> None:
# Notes:
# - lparen and rparen may be empty
# - lparen, rparen, and rel_tok may be empty
# - in each names tuple, the first item contains tokens for
# 'name [as name]' and the second item is a comma or empty.
self.from_tok = from_tok
self.rel_toks = rel_toks
self.components = components
self.import_tok = import_tok
self.lparen = lparen
Expand Down
14 changes: 10 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Abstract syntax tree node classes (i.e. parse tree)."""

import os
import re
from abc import abstractmethod, ABCMeta

Expand Down Expand Up @@ -140,6 +141,9 @@ def fullname(self) -> str:
def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_mypy_file(self)

def is_package_init_file(self) -> bool:
return not (self.path is None) and len(self.path) != 0 \
and os.path.basename(self.path) == '__init__.py'

class ImportBase(Node):
"""Base class for all import statements."""
Expand All @@ -163,20 +167,22 @@ class ImportFrom(ImportBase):

names = Undefined(List[Tuple[str, str]]) # Tuples (name, as name)

def __init__(self, id: str, names: List[Tuple[str, str]]) -> None:
def __init__(self, id: str, relative: int, names: List[Tuple[str, str]]) -> None:
self.id = id
self.names = names

self.relative = relative

def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_import_from(self)


class ImportAll(ImportBase):
"""from m import *"""

def __init__(self, id: str) -> None:
def __init__(self, id: str, relative: int) -> None:
self.id = id

self.relative = relative

def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_import_all(self)

Expand Down
1 change: 1 addition & 0 deletions mypy/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def visit_import_all(self, o):
def output_import_from_or_all(self, o):
r = o.repr
self.token(r.from_tok)
self.tokens(r.rel_toks)
self.tokens(r.components)
self.token(r.import_tok)
self.token(r.lparen)
Expand Down
26 changes: 22 additions & 4 deletions mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,35 @@ def parse_import(self) -> Import:

def parse_import_from(self) -> Node:
from_tok = self.expect('from')
name, components = self.parse_qualified_name()

# Build the list of beginning relative tokens.
relative = 0
rel_toks = List[Token]()
while self.current_str() == ".":
rel_toks.append(self.expect('.'))
relative += 1

# Parse qualified name to actually import from.
if self.current_str() == "import":
# Empty/defualt values.
name = ""
components = List[Token]()
else:
name, components = self.parse_qualified_name()

if name == self.custom_typing_module:
name = 'typing'

# Parse import list
import_tok = self.expect('import')
name_toks = List[Tuple[List[Token], Token]]()
lparen = none
rparen = none
node = None # type: ImportBase
if self.current_str() == '*':
# An import all from a module node:
name_toks.append(([self.skip()], none))
node = ImportAll(name)
node = ImportAll(name, relative)
else:
is_paren = self.current_str() == '('
if is_paren:
Expand All @@ -206,12 +224,12 @@ def parse_import_from(self) -> Node:
if is_paren:
rparen = self.expect(')')
if node is None:
node = ImportFrom(name, targets)
node = ImportFrom(name, relative, targets)
br = self.expect_break()
self.imports.append(node)
# TODO: Fix representation if there is a custom typing module import.
self.set_repr(node, noderepr.ImportFromRepr(
from_tok, components, import_tok, lparen, name_toks, rparen, br))
from_tok, rel_toks, components, import_tok, lparen, name_toks, rparen, br))
if name == '__future__':
self.future_options.extend(target[0] for target in targets)
return node
Expand Down
30 changes: 25 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def __init__(self, lib_path: List[str], errors: Errors,

def visit_file(self, file_node: MypyFile, fnam: str) -> None:
self.errors.set_file(fnam)
self.globals = file_node.names
self.cur_mod_node = file_node
self.cur_mod_id = file_node.fullname()
self.globals = file_node.names

if 'builtins' in self.modules:
self.globals['__builtins__'] = SymbolTableNode(
Expand Down Expand Up @@ -605,8 +606,9 @@ def add_module_symbol(self, id: str, as_id: str, context: Context) -> None:
self.add_unknown_symbol(as_id, context)

def visit_import_from(self, i: ImportFrom) -> None:
if i.id in self.modules:
m = self.modules[i.id]
i_id = self.correct_relative_import(i)
if i_id in self.modules:
m = self.modules[i_id]
for id, as_id in i.names:
node = m.names.get(id, None)
if node:
Expand All @@ -628,9 +630,27 @@ def normalize_type_alias(self, node: SymbolTableNode,
node = self.lookup_qualified(type_aliases[node.fullname], ctx)
return node

def correct_relative_import(self, node: Union[ImportFrom, ImportAll]) -> str:
if node.relative == 0:
return node.id

parts = self.cur_mod_id.split(".")
cur_mod_id = self.cur_mod_id

rel = node.relative
if self.cur_mod_node.is_package_init_file():
rel -= 1
if len(parts) < rel:
self.fail("Relative import climbs too many namespaces.", node)
if rel != 0:
cur_mod_id = ".".join(parts[:-rel])

return cur_mod_id + (("." + node.id) if node.id else "")

def visit_import_all(self, i: ImportAll) -> None:
if i.id in self.modules:
m = self.modules[i.id]
i_id = self.correct_relative_import(i)
if i_id in self.modules:
m = self.modules[i_id]
for name, node in m.names.items():
node = self.normalize_type_alias(node, i)
if not name.startswith('_'):
Expand Down
4 changes: 2 additions & 2 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def visit_import_from(self, o):
a = []
for name, as_name in o.names:
a.append('{} : {}'.format(name, as_name))
return 'ImportFrom:{}({}, [{}])'.format(o.line, o.id, ', '.join(a))
return 'ImportFrom:{}({}, [{}])'.format(o.line, "." * o.relative + o.id, ', '.join(a))

def visit_import_all(self, o):
return 'ImportAll:{}({})'.format(o.line, o.id)
return 'ImportAll:{}({})'.format(o.line, "." * o.relative + o.id)

# Definitions

Expand Down
109 changes: 109 additions & 0 deletions mypy/test/data/semanal-modules.test
Original file line number Diff line number Diff line change
Expand Up @@ -632,3 +632,112 @@ MypyFile:1(
AssignmentStmt:1(
NameExpr(y* [x.y])
IntExpr(1)))

[case testRelativeImport0]
import m.x
m.x.z.y
[file m/__init__.py]
[file m/x.py]
from . import z
[file m/z.py]
y = 1
[out]
MypyFile:1(
Import:1(m.x : m.x)
ExpressionStmt:2(
MemberExpr:2(
MemberExpr:2(
MemberExpr:2(
NameExpr(m)
x [m.x])
z [m.z])
y [m.z.y])))
MypyFile:1(
tmp/m/x.py
ImportFrom:1(., [z : z]))
MypyFile:1(
tmp/m/z.py
AssignmentStmt:1(
NameExpr(y* [m.z.y])
IntExpr(1)))

[case testRelativeImport1]
import m.t.b as b
b.x.y
b.z.y
[file m/__init__.py]
[file m/x.py]
y = 1
[file m/z.py]
y = 3
[file m/t/__init__.py]
[file m/t/b.py]
from .. import x, z
[out]
MypyFile:1(
Import:1(m.t.b : b)
ExpressionStmt:2(
MemberExpr:2(
MemberExpr:2(
NameExpr(b [m.t.b])
x [m.x])
y [m.x.y]))
ExpressionStmt:3(
MemberExpr:3(
MemberExpr:3(
NameExpr(b [m.t.b])
z [m.z])
y [m.z.y])))
MypyFile:1(
tmp/m/t/b.py
ImportFrom:1(.., [x : x, z : z]))
MypyFile:1(
tmp/m/x.py
AssignmentStmt:1(
NameExpr(y* [m.x.y])
IntExpr(1)))
MypyFile:1(
tmp/m/z.py
AssignmentStmt:1(
NameExpr(y* [m.z.y])
IntExpr(3)))

[case testRelativeImport2]
import m.t.b as b
b.xy
b.zy
[file m/__init__.py]
[file m/x.py]
y = 1
[file m/z.py]
y = 3
[file m/t/__init__.py]
[file m/t/b.py]
from ..x import y as xy
from ..z import y as zy
[out]
MypyFile:1(
Import:1(m.t.b : b)
ExpressionStmt:2(
MemberExpr:2(
NameExpr(b [m.t.b])
xy [m.x.y]))
ExpressionStmt:3(
MemberExpr:3(
NameExpr(b [m.t.b])
zy [m.z.y])))
MypyFile:1(
tmp/m/t/b.py
ImportFrom:1(..x, [y : xy])
ImportFrom:2(..z, [y : zy]))
MypyFile:1(
tmp/m/x.py
AssignmentStmt:1(
NameExpr(y* [m.x.y])
IntExpr(1)))
MypyFile:1(
tmp/m/z.py
AssignmentStmt:1(
NameExpr(y* [m.z.y])
IntExpr(3)))

4 changes: 2 additions & 2 deletions mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def visit_import(self, node: Import) -> Node:
return Import(node.ids[:])

def visit_import_from(self, node: ImportFrom) -> Node:
return ImportFrom(node.id, node.names[:])
return ImportFrom(node.id, node.relative, node.names[:])

def visit_import_all(self, node: ImportAll) -> Node:
return ImportAll(node.id)
return ImportAll(node.id, node.relative)

def visit_func_def(self, node: FuncDef) -> FuncDef:
# Note that a FuncDef must be transformed to a FuncDef.
Expand Down