Skip to content

Commit 21725e9

Browse files
authored
Merge pull request #4285 from kchmck/fix-4046
Fix problems with running tests in package `__init__` files (#4046)
2 parents 48f52b1 + 5197354 commit 21725e9

File tree

5 files changed

+36
-20
lines changed

5 files changed

+36
-20
lines changed

changelog/4046.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix problems with running tests in package ``__init__.py`` files.

src/_pytest/main.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,6 @@ def pytest_ignore_collect(path, config):
281281
if _in_venv(path) and not allow_in_venv:
282282
return True
283283

284-
# Skip duplicate paths.
285-
keepduplicates = config.getoption("keepduplicates")
286-
duplicate_paths = config.pluginmanager._duplicatepaths
287-
if not keepduplicates:
288-
if path in duplicate_paths:
289-
return True
290-
else:
291-
duplicate_paths.add(path)
292-
293284
return False
294285

295286

@@ -551,14 +542,32 @@ def _collect(self, arg):
551542
col = root._collectfile(argpath)
552543
if col:
553544
self._node_cache[argpath] = col
554-
for y in self.matchnodes(col, names):
545+
m = self.matchnodes(col, names)
546+
# If __init__.py was the only file requested, then the matched node will be
547+
# the corresponding Package, and the first yielded item will be the __init__
548+
# Module itself, so just use that. If this special case isn't taken, then all
549+
# the files in the package will be yielded.
550+
if argpath.basename == "__init__.py":
551+
yield next(m[0].collect())
552+
return
553+
for y in m:
555554
yield y
556555

557556
def _collectfile(self, path):
558557
ihook = self.gethookproxy(path)
559558
if not self.isinitpath(path):
560559
if ihook.pytest_ignore_collect(path=path, config=self.config):
561560
return ()
561+
562+
# Skip duplicate paths.
563+
keepduplicates = self.config.getoption("keepduplicates")
564+
if not keepduplicates:
565+
duplicate_paths = self.config.pluginmanager._duplicatepaths
566+
if path in duplicate_paths:
567+
return ()
568+
else:
569+
duplicate_paths.add(path)
570+
562571
return ihook.pytest_collect_file(path=path, parent=self)
563572

564573
def _recurse(self, path):

src/_pytest/python.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -553,15 +553,6 @@ def isinitpath(self, path):
553553
return path in self.session._initialpaths
554554

555555
def collect(self):
556-
# XXX: HACK!
557-
# Before starting to collect any files from this package we need
558-
# to cleanup the duplicate paths added by the session's collect().
559-
# Proper fix is to not track these as duplicates in the first place.
560-
for path in list(self.session.config.pluginmanager._duplicatepaths):
561-
# if path.parts()[:len(self.fspath.dirpath().parts())] == self.fspath.dirpath().parts():
562-
if path.dirname.startswith(self.name):
563-
self.session.config.pluginmanager._duplicatepaths.remove(path)
564-
565556
this_path = self.fspath.dirpath()
566557
init_module = this_path.join("__init__.py")
567558
if init_module.check(file=1) and path_matches_patterns(

testing/test_collection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,21 @@ def test_collect_init_tests(testdir):
957957
"*<Function 'test_foo'>",
958958
]
959959
)
960+
result = testdir.runpytest("./tests", "--collect-only")
961+
result.stdout.fnmatch_lines(
962+
[
963+
"*<Module '__init__.py'>",
964+
"*<Function 'test_init'>",
965+
"*<Module 'test_foo.py'>",
966+
"*<Function 'test_foo'>",
967+
]
968+
)
969+
result = testdir.runpytest("./tests/test_foo.py", "--collect-only")
970+
result.stdout.fnmatch_lines(["*<Module 'test_foo.py'>", "*<Function 'test_foo'>"])
971+
assert "test_init" not in result.stdout.str()
972+
result = testdir.runpytest("./tests/__init__.py", "--collect-only")
973+
result.stdout.fnmatch_lines(["*<Module '__init__.py'>", "*<Function 'test_init'>"])
974+
assert "test_foo" not in result.stdout.str()
960975

961976

962977
def test_collect_invalid_signature_message(testdir):

testing/test_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ class TestY(TestX):
219219
started = reprec.getcalls("pytest_collectstart")
220220
finished = reprec.getreports("pytest_collectreport")
221221
assert len(started) == len(finished)
222-
assert len(started) == 7 # XXX extra TopCollector
222+
assert len(started) == 8
223223
colfail = [x for x in finished if x.failed]
224224
assert len(colfail) == 1
225225

0 commit comments

Comments
 (0)