diff --git a/_pytest/mark.py b/_pytest/mark.py index 50581e0a890..bd1c61bf48b 100644 --- a/_pytest/mark.py +++ b/_pytest/mark.py @@ -1,5 +1,7 @@ """ generic mechanism for marking and selecting python functions. """ import inspect +import types +from copy import copy class MarkerError(Exception): @@ -253,10 +255,17 @@ def __call__(self, *args, **kwargs): """ if passed a single callable argument: decorate it with mark info. otherwise add *args/**kwargs in-place to mark information. """ if args and not kwargs: - func = args[0] - is_class = inspect.isclass(func) - if len(args) == 1 and (istestfunc(func) or is_class): + orig_func = args[0] + is_class = inspect.isclass(orig_func) + if len(args) == 1 and (istestfunc(orig_func) or is_class): if is_class: + methods_dict = dict(orig_func.__dict__) + methods_dict.pop('__weakref__', None) + methods_dict.pop('__dict__', None) + func = type(orig_func.__name__, + orig_func.__bases__, + methods_dict) + if hasattr(func, 'pytestmark'): mark_list = func.pytestmark if not isinstance(mark_list, list): @@ -268,6 +277,13 @@ def __call__(self, *args, **kwargs): else: func.pytestmark = [self] else: + func = types.FunctionType(orig_func.__code__, + orig_func.__globals__, + orig_func.__name__, + orig_func.__defaults__, + orig_func.__closure__) + func.__dict__ = copy(orig_func.__dict__) + holder = getattr(func, self.name, None) if holder is None: holder = MarkInfo( @@ -276,6 +292,7 @@ def __call__(self, *args, **kwargs): setattr(func, self.name, holder) else: holder.add(self.args, self.kwargs) + return func kw = self.kwargs.copy() kw.update(kwargs) diff --git a/_pytest/python.py b/_pytest/python.py index d5612a584d2..f6194f640c6 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -9,7 +9,7 @@ import py import pytest from _pytest._code.code import TerminalRepr -from _pytest.mark import MarkDecorator, MarkerError +from _pytest.mark import MarkDecorator, MarkerError, MarkInfo try: import enum @@ -474,7 +474,7 @@ def _genfunctions(self, name, funcobj): module = self.getparent(Module).obj clscol = self.getparent(Class) cls = clscol and clscol.obj or None - transfer_markers(funcobj, cls, module) + funcobj = transfer_markers(funcobj, cls, module) fm = self.session._fixturemanager fixtureinfo = fm.getfixtureinfo(self, funcobj, cls) metafunc = Metafunc(funcobj, fixtureinfo, self.config, @@ -587,13 +587,19 @@ def transfer_markers(funcobj, cls, mod): pytestmark = holder.pytestmark except AttributeError: continue + if isinstance(pytestmark, list): for mark in pytestmark: if not _marked(funcobj, mark): - mark(funcobj) + funcobj = mark(funcobj) else: if not _marked(funcobj, pytestmark): - pytestmark(funcobj) + funcobj = pytestmark(funcobj) + + setattr(cls or mod, funcobj.__name__, funcobj) + + return funcobj + class Module(pytest.File, PyCollector): """ Collector for test classes and functions. """ @@ -982,6 +988,26 @@ def parametrize(self, argnames, argvalues, indirect=False, ids=None, newmarks[newmark.markname] = newmark argval = argval.args[-1] unwrapped_argvalues.append(argval) + + if inspect.isclass(argval): + pytestmark = getattr(argval, 'pytestmark', None) + + if pytestmark: + if not isinstance(pytestmark, list): + pytestmark = [pytestmark] + + for mark in pytestmark: + newkeywords.setdefault(i, {}).setdefault(mark.markname, + mark) + + if inspect.isfunction(argval): + for attr_name in argval.__dict__ or {}: + + attr = getattr(argval, attr_name) + if isinstance(attr, MarkInfo): + newkeywords.setdefault(i, {}).setdefault(attr.name, + attr) + argvalues = unwrapped_argvalues if not isinstance(argnames, (tuple, list)): diff --git a/testing/python/collect.py b/testing/python/collect.py index 752cd81e321..e2ffd8e90ae 100644 --- a/testing/python/collect.py +++ b/testing/python/collect.py @@ -387,7 +387,6 @@ def test_archival_to_version(key, value): rec = testdir.inline_run() rec.assertoutcome(passed=2) - def test_parametrize_with_non_hashable_values_indirect(self, testdir): """Test parametrization with non-hashable values with indirect parametrization.""" testdir.makepyfile(""" @@ -415,7 +414,6 @@ def test_archival_to_version(key, value): rec = testdir.inline_run() rec.assertoutcome(passed=2) - def test_parametrize_overrides_fixture(self, testdir): """Test parametrization when parameter overrides existing fixture with same name.""" testdir.makepyfile(""" @@ -443,7 +441,6 @@ def test_overridden_via_multiparam(other, value): rec = testdir.inline_run() rec.assertoutcome(passed=3) - def test_parametrize_overrides_parametrized_fixture(self, testdir): """Test parametrization when parameter overrides existing parametrized fixture with same name.""" testdir.makepyfile(""" @@ -530,6 +527,32 @@ def test2(self, x, y): assert colitems[2].name == 'test2[a-c]' assert colitems[3].name == 'test2[b-c]' + def test_parametrize_with_marked_class(self, testdir): + testdir.makepyfile(""" + import pytest + + class A(object): pass + + @pytest.mark.parametrize('a', [pytest.mark.xfail(A), A]) + def test_function(a): + assert False + """) + reprec = testdir.inline_run() + reprec.assertoutcome(skipped=1, failed=1) + + def test_parametrize_with_marked_function(self, testdir): + testdir.makepyfile(""" + import pytest + + def a(): pass + + @pytest.mark.parametrize('a', [pytest.mark.xfail(a), a]) + def test_function(a): + assert False + """) + reprec = testdir.inline_run() + reprec.assertoutcome(skipped=1, failed=1) + class TestSorting: def test_check_equality(self, testdir): diff --git a/testing/python/fixture.py b/testing/python/fixture.py index 506d8426e3c..8f99ccf94c7 100644 --- a/testing/python/fixture.py +++ b/testing/python/fixture.py @@ -1025,7 +1025,7 @@ def test_one(self): def test_two(self): assert self.hello == "world" assert len(l) == 1 - pytest.mark.usefixtures("myfix")(TestClass) + TestClass = pytest.mark.usefixtures("myfix")(TestClass) """) reprec = testdir.inline_run() reprec.assertoutcome(passed=2) diff --git a/testing/test_mark.py b/testing/test_mark.py index 1795928f02e..7805ad0e633 100644 --- a/testing/test_mark.py +++ b/testing/test_mark.py @@ -25,14 +25,14 @@ def test_pytest_mark_bare(self): mark = Mark() def f(): pass - mark.hello(f) + f = mark.hello(f) assert f.hello def test_pytest_mark_keywords(self): mark = Mark() def f(): pass - mark.world(x=3, y=4)(f) + f = mark.world(x=3, y=4)(f) assert f.world assert f.world.kwargs['x'] == 3 assert f.world.kwargs['y'] == 4 @@ -42,12 +42,12 @@ def test_apply_multiple_and_merge(self): def f(): pass mark.world - mark.world(x=3)(f) + f = mark.world(x=3)(f) assert f.world.kwargs['x'] == 3 - mark.world(y=4)(f) + f = mark.world(y=4)(f) assert f.world.kwargs['x'] == 3 assert f.world.kwargs['y'] == 4 - mark.world(y=1)(f) + f = mark.world(y=1)(f) assert f.world.kwargs['y'] == 1 assert len(f.world.args) == 0 @@ -55,9 +55,9 @@ def test_pytest_mark_positional(self): mark = Mark() def f(): pass - mark.world("hello")(f) + f = mark.world("hello")(f) assert f.world.args[0] == "hello" - mark.world("world")(f) + f = mark.world("world")(f) def test_pytest_mark_positional_func_and_keyword(self): mark = Mark() @@ -66,21 +66,21 @@ def f(): m = mark.world(f, omega="hello") def g(): pass - assert m(g) == g - assert g.world.args[0] is f - assert g.world.kwargs["omega"] == "hello" + assert m(g) != g + assert m(g).world.args[0] is f + assert m(g).world.kwargs["omega"] == "hello" def test_pytest_mark_reuse(self): mark = Mark() def f(): pass w = mark.some - w("hello", reason="123")(f) + f = w("hello", reason="123")(f) assert f.some.args[0] == "hello" assert f.some.kwargs['reason'] == "123" def g(): pass - w("world", reason2="456")(g) + g = w("world", reason2="456")(g) assert g.some.args[0] == "world" assert 'reason' not in g.some.kwargs assert g.some.kwargs['reason2'] == "456"