Skip to content

Improve import performance of assertion rewrite. Fixes #3918. #3919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ env/
.ropeproject
.idea
.hypothesis
.pydevproject
.project
.settings
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Endre Galaczi
Eric Hunsberger
Eric Siegerman
Erik M. Bray
Fabio Zadrozny
Feng Ma
Florian Bruhin
Floris Bruynooghe
Expand Down
1 change: 1 addition & 0 deletions changelog/3918.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of assertion rewriting.
69 changes: 63 additions & 6 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,24 @@ def __init__(self, config):
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache = {}
self._session_paths_checked = False

def set_session(self, session):
self.session = session
self._session_paths_checked = False

def _imp_find_module(self, name, path=None):
"""Indirection so we can mock calls to find_module originated from the hook during testing"""
return imp.find_module(name, path)

def find_module(self, name, path=None):
if self._writing_pyc:
return None
state = self.config._assertstate
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1)
lastname = names[-1]
Expand All @@ -87,7 +97,7 @@ def find_module(self, name, path=None):
pth = path[0]
if pth is None:
try:
fd, fn, desc = imp.find_module(lastname, path)
fd, fn, desc = self._imp_find_module(lastname, path)
except ImportError:
return None
if fd is not None:
Expand Down Expand Up @@ -166,6 +176,44 @@ def find_module(self, name, path=None):
self.modules[name] = co, pyc
return self

def _early_rewrite_bailout(self, name, state):
"""
This is a fast way to get out of rewriting modules. Profiling has
shown that the call to imp.find_module (inside of the find_module
from this class) is a major slowdown, so, this method tries to
filter what we're sure won't be rewritten before getting to it.
"""
if self.session is not None and not self._session_paths_checked:
self._session_paths_checked = True
for path in self.session._initialpaths:
# Make something as c:/projects/my_project/path.py ->
# ['c:', 'projects', 'my_project', 'path.py']
parts = str(path).split(os.path.sep)
# add 'path' to basenames to be checked.
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

# Note: conftest already by default in _basenames_to_check_rewrite.
parts = name.split(".")
if parts[-1] in self._basenames_to_check_rewrite:
return False

# For matching the name it must be as if it was a filename.
parts[-1] = parts[-1] + ".py"
fn_pypath = py.path.local(os.path.sep.join(parts))
for pat in self.fnpats:
# if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
# on the name alone because we need to match against the full path
if os.path.dirname(pat):
return False
if fn_pypath.fnmatch(pat):
return False

if self._is_marked_for_rewrite(name, state):
return False

state.trace("early skip of rewriting module: %s" % (name,))
return True

def _should_rewrite(self, name, fn_pypath, state):
# always rewrite conftest files
fn = str(fn_pypath)
Expand All @@ -185,12 +233,20 @@ def _should_rewrite(self, name, fn_pypath, state):
state.trace("matched test file %r" % (fn,))
return True

for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
return True
return self._is_marked_for_rewrite(name, state)

return False
def _is_marked_for_rewrite(self, name, state):
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
self._marked_for_rewrite_cache[name] = True
return True

self._marked_for_rewrite_cache[name] = False
return False

def mark_rewrite(self, *names):
"""Mark import names as needing to be rewritten.
Expand All @@ -207,6 +263,7 @@ def mark_rewrite(self, *names):
):
self._warn_already_imported(name)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()

def _warn_already_imported(self, name):
self.config.warn(
Expand Down
7 changes: 4 additions & 3 deletions src/_pytest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def __init__(self, config):
self.trace = config.trace.root.get("collection")
self._norecursepatterns = config.getini("norecursedirs")
self.startdir = py.path.local()
self._initialpaths = frozenset()
# Keep track of any collected nodes in here, so we don't duplicate fixtures
self._node_cache = {}

Expand Down Expand Up @@ -441,13 +442,14 @@ def _perform_collect(self, args, genitems):
self.trace("perform_collect", self, args)
self.trace.root.indent += 1
self._notfound = []
self._initialpaths = set()
initialpaths = []
self._initialparts = []
self.items = items = []
for arg in args:
parts = self._parsearg(arg)
self._initialparts.append(parts)
self._initialpaths.add(parts[0])
initialpaths.append(parts[0])
self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep)
self.trace.root.indent -= 1
Expand Down Expand Up @@ -564,7 +566,6 @@ def _tryconvertpyarg(self, x):
"""Convert a dotted module name to path.

"""

try:
with _patched_find_module():
loader = pkgutil.find_loader(x)
Expand Down
107 changes: 93 additions & 14 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,22 +1106,21 @@ def test_ternary_display():


class TestIssue2121:
def test_simple(self, testdir):
testdir.tmpdir.join("tests/file.py").ensure().write(
"""
def test_simple_failure():
assert 1 + 1 == 3
"""
)
testdir.tmpdir.join("pytest.ini").write(
textwrap.dedent(
def test_rewrite_python_files_contain_subdirs(self, testdir):
testdir.makepyfile(
**{
"tests/file.py": """
def test_simple_failure():
assert 1 + 1 == 3
"""
[pytest]
python_files = tests/**.py
"""
)
}
)
testdir.makeini(
"""
[pytest]
python_files = tests/**.py
"""
)

result = testdir.runpytest()
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")

Expand Down Expand Up @@ -1153,3 +1152,83 @@ def spy_write_pyc(*args, **kwargs):
hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None
assert len(write_pyc_called) == 1


class TestEarlyRewriteBailout(object):
@pytest.fixture
def hook(self, pytestconfig, monkeypatch, testdir):
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
if imp.find_module has been called.
"""
import imp

self.find_module_calls = []
self.initial_paths = set()

class StubSession(object):
_initialpaths = self.initial_paths

def isinitpath(self, p):
return p in self._initialpaths

def spy_imp_find_module(name, path):
self.find_module_calls.append(name)
return imp.find_module(name, path)

hook = AssertionRewritingHook(pytestconfig)
# use default patterns, otherwise we inherit pytest's testing config
hook.fnpats[:] = ["test_*.py", "*_test.py"]
monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
hook.set_session(StubSession())
testdir.syspathinsert()
return hook

def test_basic(self, testdir, hook):
"""
Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
to optimize assertion rewriting (#3918).
"""
testdir.makeconftest(
"""
import pytest
@pytest.fixture
def fix(): return 1
"""
)
testdir.makepyfile(test_foo="def test_foo(): pass")
testdir.makepyfile(bar="def bar(): pass")
foobar_path = testdir.makepyfile(foobar="def foobar(): pass")
self.initial_paths.add(foobar_path)

# conftest files should always be rewritten
assert hook.find_module("conftest") is not None
assert self.find_module_calls == ["conftest"]

# files matching "python_files" mask should always be rewritten
assert hook.find_module("test_foo") is not None
assert self.find_module_calls == ["conftest", "test_foo"]

# file does not match "python_files": early bailout
assert hook.find_module("bar") is None
assert self.find_module_calls == ["conftest", "test_foo"]

# file is an initial path (passed on the command-line): should be rewritten
assert hook.find_module("foobar") is not None
assert self.find_module_calls == ["conftest", "test_foo", "foobar"]

def test_pattern_contains_subdirectories(self, testdir, hook):
"""If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
because we need to match with the full path, which can only be found by calling imp.find_module.
"""
p = testdir.makepyfile(
**{
"tests/file.py": """
def test_simple_failure():
assert 1 + 1 == 3
"""
}
)
testdir.syspathinsert(p.dirpath())
hook.fnpats[:] = ["tests/**.py"]
assert hook.find_module("file") is not None
assert self.find_module_calls == ["file"]