diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 1845d9d91ef..0688c2c0ecb 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: + from types import ModuleType # noqa: F401 (used in type string) from typing import Type # noqa: F401 (used in type string) @@ -336,28 +337,28 @@ def safe_isclass(obj: object) -> bool: return False -COLLECT_FAKEMODULE_ATTRIBUTES = ( - "Collector", - "Module", - "Function", - "Instance", - "Session", - "Item", - "Class", - "File", - "_fillfuncargs", -) - - -def _setup_collect_fakemodule() -> None: +def _setup_collect_fakemodule() -> "ModuleType": + """Setup pytest.collect fake module for backward compatibility.""" from types import ModuleType - import pytest + import _pytest.nodes + + collect_fakemodule_attributes = ( + ("Collector", _pytest.nodes.Collector), + ("Module", _pytest.python.Module), + ("Function", _pytest.python.Function), + ("Instance", _pytest.python.Instance), + ("Session", _pytest.main.Session), + ("Item", _pytest.nodes.Item), + ("Class", _pytest.python.Class), + ("File", _pytest.nodes.File), + ("_fillfuncargs", _pytest.fixtures.fillfixtures), + ) - # Types ignored because the module is created dynamically. - pytest.collect = ModuleType("pytest.collect") # type: ignore - pytest.collect.__all__ = [] # type: ignore # used for setns - for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES: - setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore + mod = ModuleType("pytest.collect") + mod.__all__ = [] # type: ignore # used for setns (obsolete?) + for attr_name, value in collect_fakemodule_attributes: + setattr(mod, attr_name, value) + return mod class CaptureIO(io.TextIOWrapper): diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index 33bc3d0fbe5..6ea2cf4cf45 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -2,9 +2,10 @@ """ pytest: unit and functional testing with Python. """ +import sys + from _pytest import __version__ from _pytest.assertion import register_assert_rewrite -from _pytest.compat import _setup_collect_fakemodule from _pytest.config import cmdline from _pytest.config import ExitCode from _pytest.config import hookimpl @@ -46,7 +47,6 @@ from _pytest.warning_types import PytestUnknownMarkWarning from _pytest.warning_types import PytestWarning - set_trace = __pytestPDB.set_trace __all__ = [ @@ -95,5 +95,18 @@ ] -_setup_collect_fakemodule() -del _setup_collect_fakemodule +if sys.version_info >= (3, 7): + + def __getattr__(name): + if name == "collect": + from _pytest.compat import _setup_collect_fakemodule + + return _setup_collect_fakemodule() + raise AttributeError(name) + + +else: + from _pytest.compat import _setup_collect_fakemodule + + collect = _setup_collect_fakemodule() + del _setup_collect_fakemodule diff --git a/testing/conftest.py b/testing/conftest.py index 90cdcb869fd..f03835e19b2 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -1,3 +1,4 @@ +import os import re import sys from typing import List @@ -136,6 +137,27 @@ def testdir(testdir: Testdir) -> Testdir: return testdir +@pytest.fixture +def symlink_or_skip(): + """Return a function that creates a symlink or raises ``Skip``. + + On Windows `os.symlink` is available, but normal users require special + admin privileges to create symlinks. + """ + + def wrap_os_symlink(src, dst, *args, **kwargs): + if os.path.islink(dst): + return + + try: + os.symlink(src, dst, *args, **kwargs) + except OSError as e: + pytest.skip("os.symlink({!r}) failed: {!r}".format((src, dst), e)) + assert os.path.islink(dst) + + return wrap_os_symlink + + @pytest.fixture(scope="session") def color_mapping(): """Returns a utility class which can replace keys in strings in the form "{NAME}" diff --git a/testing/test_meta.py b/testing/test_meta.py index ffc8fd38aba..6dbc47ddc41 100644 --- a/testing/test_meta.py +++ b/testing/test_meta.py @@ -10,6 +10,7 @@ import _pytest import pytest +from _pytest.pytester import Testdir def _modules(): @@ -33,3 +34,39 @@ def test_no_warnings(module): "-c", "import {}".format(module), )) # fmt: on + + +def test_pytest_collect_attribute(_sys_snapshot): + from types import ModuleType + + del sys.modules["pytest"] + + import pytest + + assert isinstance(pytest.collect, ModuleType) + assert pytest.collect.Item is pytest.Item + + with pytest.raises(ImportError): + import pytest.collect + + if sys.version_info >= (3, 7): + with pytest.raises(AttributeError, match=r"^doesnotexist$"): + pytest.doesnotexist + else: + with pytest.raises(AttributeError, match=r"doesnotexist"): + pytest.doesnotexist + + +def test_pytest_circular_import(testdir: Testdir, symlink_or_skip) -> None: + """Importing pytest should not import pytest itself.""" + import pytest + import os.path + + symlink_or_skip(os.path.dirname(pytest.__file__), "another") + + del sys.modules["pytest"] + + testdir.syspathinsert() + import another # noqa: F401 + + assert "pytest" not in sys.modules