Skip to content

[mypyc] Fix using values from other modules that were reexported #7496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 16 additions & 37 deletions mypyc/genops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def f(x: int) -> int:
from typing import (
TypeVar, Callable, Dict, List, Tuple, Optional, Union, Sequence, Set, Any, cast
)
from typing_extensions import overload, ClassVar, NoReturn
from typing_extensions import overload, NoReturn
from collections import OrderedDict
from abc import abstractmethod
import sys
import importlib.util
import itertools

Expand Down Expand Up @@ -192,7 +192,7 @@ def build_ir(modules: List[MypyFile],
builder = IRBuilder(types, graph, errors, mapper, module_names, pbv, options)
builder.visit_mypy_file(module)
module_ir = ModuleIR(
builder.imports,
list(builder.imports),
builder.functions,
builder.classes,
builder.final_names
Expand Down Expand Up @@ -1050,7 +1050,10 @@ def __init__(self,

self.errors = errors
self.mapper = mapper
self.imports = [] # type: List[str]
# Notionally a list of all of the modules imported by the
# module being compiled, but stored as an OrderedDict so we
# can also do quick lookups.
self.imports = OrderedDict() # type: OrderedDict[str, None]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document this attribute. Maybe this would benefit from renaming -- it took me a while to figure out what it's used for.


def visit_mypy_file(self, mypyfile: MypyFile) -> None:
if mypyfile.fullname() in ('typing', 'abc'):
Expand Down Expand Up @@ -1518,29 +1521,8 @@ def allocate_class(self, cdef: ClassDef) -> None:
[self.load_globals_dict(), self.load_static_unicode(cdef.name),
tp], cdef.line)

# An unfortunate hack: for some stdlib modules, pull in modules
# that the stubs reexport things from. This works around #393
# in these cases.
import_maps = {
'os': tuple(['os.path'] + ([] if sys.platform == 'win32' else ['posix'])),
'os.path': ('os',),
'tokenize': ('token',),
'weakref': ('_weakref',),
'asyncio': ('asyncio.events', 'asyncio.tasks',),
'click': ('click.core', 'click.termui', 'click.decorators',
'click.exceptions', 'click.types'),
'ast': ('_ast',),
} # type: ClassVar[Dict[str, Sequence[str]]]

def gen_import(self, id: str, line: int) -> None:
if id in IRBuilder.import_maps:
for dep in IRBuilder.import_maps[id]:
self._gen_import(dep, line)

self._gen_import(id, line)

def _gen_import(self, id: str, line: int) -> None:
self.imports.append(id)
self.imports[id] = None

needs_import, out = BasicBlock(), BasicBlock()
first_load = self.add(LoadStatic(object_rprimitive, 'module', id))
Expand Down Expand Up @@ -2817,7 +2799,7 @@ def visit_name_expr(self, expr: NameExpr) -> Value:
if value is not None:
return value

if isinstance(expr.node, MypyFile):
if isinstance(expr.node, MypyFile) and expr.node.fullname() in self.imports:
return self.load_module(expr.node.fullname())

# If the expression is locally defined, then read the result from the corresponding
Expand Down Expand Up @@ -2853,11 +2835,11 @@ def visit_member_expr(self, expr: MemberExpr) -> Value:
if value is not None:
return value

if self.is_module_member_expr(expr):
return self.load_module_attr(expr)
else:
obj = self.accept(expr.expr)
return self.get_attr(obj, expr.name, self.node_type(expr), expr.line)
if isinstance(expr.node, MypyFile) and expr.node.fullname() in self.imports:
return self.load_module(expr.node.fullname())

obj = self.accept(expr.expr)
return self.get_attr(obj, expr.name, self.node_type(expr), expr.line)

def get_attr(self, obj: Value, attr: str, result_type: RType, line: int) -> Value:
if (isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class
Expand Down Expand Up @@ -5210,7 +5192,8 @@ def load_global(self, expr: NameExpr) -> Value:
"""
# If the global is from 'builtins', turn it into a module attr load instead
if self.is_builtin_ref_expr(expr):
return self.load_module_attr(expr)
assert expr.node, "RefExpr not resolved"
return self.load_module_attr_by_fullname(expr.node.fullname(), expr.line)
if (self.is_native_module_ref_expr(expr) and isinstance(expr.node, TypeInfo)
and not self.is_synthetic_type(expr.node)):
assert expr.fullname is not None
Expand Down Expand Up @@ -5257,10 +5240,6 @@ def load_static_unicode(self, value: str) -> Value:
def load_module(self, name: str) -> Value:
return self.add(LoadStatic(object_rprimitive, 'module', name))

def load_module_attr(self, expr: RefExpr) -> Value:
assert expr.node, "RefExpr not resolved"
return self.load_module_attr_by_fullname(expr.node.fullname(), expr.line)

def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value:
module, _, name = fullname.rpartition('.')
left = self.load_module(module)
Expand Down
28 changes: 28 additions & 0 deletions mypyc/test-data/run.test
Original file line number Diff line number Diff line change
Expand Up @@ -4645,3 +4645,31 @@ from native import foo

assert foo(None) == None
assert foo([1, 2, 3]) == ((1, 2, 3), [1, 2, 3])

[case testReexport]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also test re-exporting a module or a class. What about re-exporting a global variable?

# Test that we properly handle accessing values that have been reexported
import a
def f(x: int) -> int:
return a.g(x) + a.foo + a.b.foo

whatever = a.A()

[file a.py]
from b import g as g, A as A, foo as foo
import b

[file b.py]
def g(x: int) -> int:
return x + 1

class A:
pass

foo = 20

[file driver.py]
from native import f, whatever
import b

assert f(20) == 61
assert isinstance(whatever, b.A)