Skip to content

Commit e4fe41e

Browse files
authored
Merge pull request #5356 from asottile/fix_parametrize_iterator
Fix `pytest.mark.parametrize` when the argvalue is an iterator
2 parents a8f4e56 + cafb13c commit e4fe41e

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

changelog/5354.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ``pytest.mark.parametrize`` when the argvalues is an iterator.

src/_pytest/mark/structures.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,18 @@ def _parse_parametrize_args(argnames, argvalues, **_):
113113
force_tuple = len(argnames) == 1
114114
else:
115115
force_tuple = False
116-
parameters = [
116+
return argnames, force_tuple
117+
118+
@staticmethod
119+
def _parse_parametrize_parameters(argvalues, force_tuple):
120+
return [
117121
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
118122
]
119-
return argnames, parameters
120123

121124
@classmethod
122125
def _for_parametrize(cls, argnames, argvalues, func, config, function_definition):
123-
argnames, parameters = cls._parse_parametrize_args(argnames, argvalues)
126+
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
127+
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
124128
del argvalues
125129

126130
if parameters:

testing/test_mark.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,28 @@ def test_func(a, b):
413413
assert result.ret == 0
414414

415415

416+
def test_parametrize_iterator(testdir):
417+
"""parametrize should work with generators (#5354)."""
418+
py_file = testdir.makepyfile(
419+
"""\
420+
import pytest
421+
422+
def gen():
423+
yield 1
424+
yield 2
425+
yield 3
426+
427+
@pytest.mark.parametrize('a', gen())
428+
def test(a):
429+
assert a >= 1
430+
"""
431+
)
432+
result = testdir.runpytest(py_file)
433+
assert result.ret == 0
434+
# should not skip any tests
435+
result.stdout.fnmatch_lines(["*3 passed*"])
436+
437+
416438
class TestFunctional(object):
417439
def test_merging_markers_deep(self, testdir):
418440
# issue 199 - propagate markers into nested classes

0 commit comments

Comments
 (0)